Compare commits


4 Commits

SHA1 Message Date
e3d00beddd Fix triu_/tril_ overlap handling 2025-10-21 07:54:24 -07:00
21131a2444 Revert "[ROCm][CI] Update rocm.yml workflow to use 1 GPU ARC runners (#165481)"
This reverts commit ffa90d46e61650834d5f926008f48f50c6a7e87a.

Reverted https://github.com/pytorch/pytorch/pull/165481 on behalf of https://github.com/jeffdaily due to timeouts after merge ([comment](https://github.com/pytorch/pytorch/pull/165481#issuecomment-3426898171))
2025-10-21 14:15:55 +00:00
1009790ad8 [pytree][dynamo] trace on native optree functions for community pytree support (#165860)
Resolves #164972

- #164972

All `torch.utils._cxx_pytree` functions are based on `optree` functions with `none_is_leaf=True` and `namespace="torch"` hardcoded. This PR retargets the polyfills at the generic `optree` functions, passing those arguments explicitly instead of hardcoding them. As a result, the `torch.utils._cxx_pytree` functions remain traceable, and community code that calls `optree` directly additionally gains Dynamo support.
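
As an illustration (not part of the PR), here is a minimal sketch of the relationship described above, assuming `optree` is installed; the wrapper name `torch_flavored_tree_flatten` is hypothetical:

```python
# Illustrative sketch only -- not code from this PR.
import optree


def torch_flavored_tree_flatten(tree, is_leaf=None):
    # torch.utils._cxx_pytree effectively calls optree with these two
    # arguments pinned; the polyfills now target the generic optree form,
    # so both flavors can be traced by Dynamo.
    return optree.tree_flatten(
        tree, is_leaf, none_is_leaf=True, namespace="torch"
    )


# A plain community-style optree call, which previously hit the optree C
# extension and caused a graph break under torch.compile.
leaves, spec = optree.tree_flatten({"a": 1, "b": (2, 3)})
assert leaves == [1, 2, 3]
```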

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165860
Approved by: https://github.com/Lucaskabela
2025-10-21 14:13:08 +00:00
410e6a4321 Better error handling in torch/csrc/jit/frontend/* (#165213)
Refactor error handling to use TORCH_CHECK, improving clarity in the constants and scope-management code in several files under torch/csrc/jit/frontend/*.

Fixes parts of issue https://github.com/pytorch/pytorch/issues/148114
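
For user-facing context, a hedged Python-side sketch (not from the PR): errors raised via TORCH_CHECK surface in Python as RuntimeError, just like the std::runtime_error calls they replace, so existing exception handling keeps working. The trigger below is an assumption about one of the touched tracer paths; the exact behavior may vary across versions.

```python
# Hedged illustration, not from the PR. It assumes a plain requires-grad
# tensor attribute reaches the tracer's constant-baking path, one of the
# error sites converted to TORCH_CHECK in this change.
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain attribute, not an nn.Parameter, so the tracer tries to
        # bake it into the traced graph as a constant.
        self.w = torch.randn(2, 2, requires_grad=True)

    def forward(self, x):
        return x @ self.w


try:
    torch.jit.trace(M(), torch.randn(2, 2))
except RuntimeError as err:
    # Still a RuntimeError after the TORCH_CHECK refactor.
    print("caught RuntimeError:", "requires grad as a constant" in str(err))
```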

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165213
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-21 13:54:59 +00:00
15 changed files with 267 additions and 125 deletions

View File

@ -54,17 +54,12 @@ self-hosted-runner:
- windows-11-arm64
- windows-11-arm64-preview
# Organization-wide AMD-hosted runners
# MI2xx non-ARC runners
# MI2xx runners
- linux.rocm.gpu
- linux.rocm.gpu.mi250
- linux.rocm.gpu.2
- linux.rocm.gpu.4
- linux.rocm.gpu.mi250
- linux.rocm.gpu.gfx1100
# MI2xx ARC runners
- linux.rocm.gpu.mi250.1
- linux.rocm.gpu.mi250.2
- linux.rocm.gpu.mi250.4
# gfx942 ARC runners
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4

View File

@ -36,12 +36,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi250.1" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit

View File

@ -141,6 +141,8 @@ void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
return;
}
checkTrilTriuMemoryOverlap(result, self);
bool inplace_op = self.is_same(result);
bool inplace_update = false;

View File

@ -1,3 +1,4 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/LinearAlgebraUtils.h>
@ -54,4 +55,13 @@ static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor
return std::make_tuple(true, tensor);
}
static inline void checkTrilTriuMemoryOverlap(const Tensor& result, const Tensor& self) {
if (result.is_same(self)) {
at::assert_no_internal_overlap(result);
} else {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
}
}
} // namespace at::native

View File

@ -5,6 +5,7 @@
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TriangularOpsUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -110,6 +111,8 @@ __global__ void triu_tril_kernel(
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
checkTrilTriuMemoryOverlap(result, self);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
at::ScalarType::ComplexHalf,
at::ScalarType::Half,

View File

@ -424,7 +424,7 @@ from user code:
@torch.compile(backend="eager")
def fn(x):
d = {"a": 1}
optree.tree_flatten(d)
optree.tree_flatten_with_path(d)
return torch.sin(x)
fn(torch.randn(4))
@ -434,10 +434,10 @@ from user code:
first_graph_break,
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten.
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: PyCapsule.flatten, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)

View File

@ -110,6 +110,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import torch.utils._cxx_pytree as cxx_pytree
pytree_modules["cxx"] = cxx_pytree
pytree_modules["native_optree"] = cxx_pytree.optree
else:
cxx_pytree = None
@ -12862,6 +12863,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12876,6 +12880,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12893,6 +12900,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12910,6 +12920,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -12931,6 +12944,9 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
def fn(xs):
flat_xs, spec = pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
if pytree.__name__ == "optree":
# The treespec argument comes first in OpTree / JAX PyTree
return pytree.tree_unflatten(spec, res)
return pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
@ -13032,7 +13048,13 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
torch.ones(3, 2),
1,
]
new_tree = pytree.tree_unflatten(new_leaves, treespec)
if pytree.__name__ == "optree":
# `None` is an internal node rather than a leaf in default OpTree / JAX PyTree
new_leaves.pop()
# The treespec argument comes first in OpTree / JAX PyTree
new_tree = pytree.tree_unflatten(treespec, new_leaves)
else:
new_tree = pytree.tree_unflatten(new_leaves, treespec)
return leaves, new_tree
x = torch.randn(3, 2)
@ -13087,6 +13109,10 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
# OpTree and JAX PyTree do not have `tree_map_only`
return
def fn(xs):
def mapper(x):
return x.clone()

View File

@ -9986,6 +9986,20 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(result_triu_min, expected_triu_min)
self.assertEqual(result_tril_min, expected_tril_min)
@dtypes(torch.float)
def test_triu_tril_inplace_memory_overlap(self, device, dtype):
base = torch.rand((), dtype=dtype, device=device)
expanded = base.expand(3, 3)
msg = (
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation."
)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.triu_(1)
with self.assertRaisesRegex(RuntimeError, msg):
expanded.tril_(-1)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

View File

@ -19,6 +19,7 @@ from torch.testing._internal.common_cuda import (
SM80OrLater,
SM90OrLater,
SM100OrLater,
xfailIfSM120OrLater,
_get_torch_cuda_version,
)
from torch.testing._internal.common_device_type import (
@ -325,6 +326,7 @@ class TestMatmulCuda(InductorTestCase):
self.assertEqual(agrad, a.grad)
self.assertEqual(bgrad, b.grad)
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@ -362,6 +364,7 @@ class TestMatmulCuda(InductorTestCase):
start = offs_cpu[i]
self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@ -417,6 +420,7 @@ class TestMatmulCuda(InductorTestCase):
self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@ -450,6 +454,7 @@ class TestMatmulCuda(InductorTestCase):
out.backward(gO)
self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
@xfailIfSM120OrLater
@unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
from typing_extensions import TypeIs
import torch.utils._pytree as python_pytree
@ -28,7 +28,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
import optree
import optree._C
import torch.utils._cxx_pytree as cxx_pytree
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
if TYPE_CHECKING:
from torch.utils._cxx_pytree import PyTree
@ -64,45 +64,69 @@ if python_pytree._cxx_pytree_dynamo_traceable:
del __func
del __name
@substitute_in_graph(cxx_pytree.tree_is_leaf, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True)
def tree_is_leaf(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> bool:
if tree is None or (is_leaf is not None and is_leaf(tree)):
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
return True
if optree.register_pytree_node.get(type(tree), namespace="torch") is None: # type: ignore[attr-defined]
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
return True
return False
@substitute_in_graph(cxx_pytree.tree_iter, can_constant_fold_through=False)
@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False)
def tree_iter(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> Iterable[Any]:
stack = [tree]
while stack:
node = stack.pop()
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
yield node
continue
children, *_ = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
stack.extend(reversed(children))
__all__ += ["tree_iter"]
@substitute_in_graph(cxx_pytree.tree_leaves, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True)
def tree_leaves(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> list[Any]:
return list(tree_iter(tree, is_leaf=is_leaf))
return list(
tree_iter(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
)
__all__ += ["tree_leaves"]
@ -127,12 +151,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
_metadata: Any
_entries: tuple[Any, ...]
_unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None
none_is_leaf: bool
namespace: str
num_nodes: int = field(init=False)
num_leaves: int = field(init=False)
num_children: int = field(init=False)
none_is_leaf: Literal[True] = field(init=False)
namespace: Literal["torch"] = field(init=False)
def __post_init__(self) -> None:
if self._type is None:
@ -152,8 +176,6 @@ if python_pytree._cxx_pytree_dynamo_traceable:
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
object.__setattr__(self, "none_is_leaf", True)
object.__setattr__(self, "namespace", "torch")
def __repr__(self) -> str:
def helper(treespec: PyTreeSpec) -> str:
@ -168,6 +190,7 @@ if python_pytree._cxx_pytree_dynamo_traceable:
]
if (
treespec.type in BUILTIN_TYPES
or (treespec.type is type(None) and not self.none_is_leaf)
or optree.is_namedtuple_class(treespec.type)
or optree.is_structseq_class(treespec.type)
):
@ -181,9 +204,12 @@ if python_pytree._cxx_pytree_dynamo_traceable:
f"[{', '.join(children_representations)}])"
)
return (
f"PyTreeSpec({helper(self)}, NoneIsLeaf, namespace={self.namespace!r})"
)
inner = [
str(helper(self)),
*(["NoneIsLeaf"] if self.none_is_leaf else []),
f"namespace={self.namespace!r}",
]
return f"PyTreeSpec({', '.join(inner)})"
def __len__(self) -> int:
return self.num_leaves
@ -228,8 +254,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if len(children) != treespec.num_children:
raise ValueError(
@ -277,8 +303,8 @@ if python_pytree._cxx_pytree_dynamo_traceable:
# node_type is treespec.type
children, metadata, *_ = optree.tree_flatten_one_level(
node,
none_is_leaf=True,
namespace="torch",
none_is_leaf=self.none_is_leaf,
namespace=self.namespace,
)
if (
node_type
@ -320,25 +346,40 @@ if python_pytree._cxx_pytree_dynamo_traceable:
assert callable(self._unflatten_func)
return self._unflatten_func(self._metadata, subtrees)
_LEAF_SPEC = PyTreeSpec((), None, None, (), None)
def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]:
return isinstance(obj, PyTreeSpec)
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_flatten,
optree.tree_flatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_flatten(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> tuple[list[Any], PyTreeSpec]:
def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec:
if tree_is_leaf(node, is_leaf=is_leaf):
if tree_is_leaf(
node,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
):
leaves.append(node)
return _LEAF_SPEC
return PyTreeSpec(
(),
None,
None,
(),
None,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
(
children,
@ -348,13 +389,21 @@ if python_pytree._cxx_pytree_dynamo_traceable:
) = optree.tree_flatten_one_level(
node,
is_leaf=is_leaf,
none_is_leaf=True,
namespace="torch",
none_is_leaf=none_is_leaf,
namespace=namespace,
)
# Recursively flatten the children
subspecs = tuple(helper(child, leaves) for child in children)
return PyTreeSpec(subspecs, type(node), metadata, entries, unflatten_func) # type: ignore[arg-type]
return PyTreeSpec(
subspecs,
type(node),
metadata,
entries,
unflatten_func,
none_is_leaf=none_is_leaf,
namespace=namespace,
) # type: ignore[arg-type]
leaves: list[Any] = []
treespec = helper(tree, leaves)
@ -363,26 +412,35 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_flatten"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_structure,
optree.tree_structure,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_structure(
tree: PyTree,
/,
is_leaf: Callable[[PyTree], bool] | None = None,
*,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTreeSpec:
return tree_flatten(tree, is_leaf=is_leaf)[1] # type: ignore[return-value]
return tree_flatten( # type: ignore[return-value]
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)[1]
__all__ += ["tree_structure"]
@substitute_in_graph( # type: ignore[arg-type]
cxx_pytree.tree_unflatten,
optree.tree_unflatten,
# We need to disable constant folding here because we want the function to reference the
# PyTreeSpec class defined above, not the one in the C++ module.
can_constant_fold_through=False,
)
def tree_unflatten(leaves: Iterable[Any], treespec: PyTreeSpec) -> PyTree:
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
@ -392,29 +450,57 @@ if python_pytree._cxx_pytree_dynamo_traceable:
__all__ += ["tree_unflatten"]
@substitute_in_graph(cxx_pytree.tree_map, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True)
def tree_map(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(map(func, *flat_args))
__all__ += ["tree_map"]
@substitute_in_graph(cxx_pytree.tree_map_, can_constant_fold_through=True)
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True)
def tree_map_(
func: Callable[..., Any],
tree: PyTree,
/,
*rests: PyTree,
is_leaf: Callable[[PyTree], bool] | None = None,
none_is_leaf: bool = False,
namespace: str = "",
) -> PyTree:
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
leaves, treespec = tree_flatten(
tree,
is_leaf=is_leaf,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree
__all__ += ["tree_map_"]
_none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr]
@substitute_in_graph( # type: ignore[arg-type]
_none_unflatten,
can_constant_fold_through=True,
skip_signature_check=True,
)
def none_unflatten(_: None, children: Iterable[Any], /) -> None:
if len(list(children)) != 0:
raise ValueError("Expected no children.")
return None

View File

@ -3259,7 +3259,7 @@ struct to_ir {
case TK_IN:
return aten::__contains__;
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -3306,7 +3306,7 @@ struct to_ir {
case TK_RSHIFT:
return "__rshift__";
default:
throw std::runtime_error("unknown kind " + std::to_string(kind));
TORCH_CHECK(false, "unknown kind ", kind);
}
}
@ -4120,8 +4120,7 @@ struct to_ir {
} else if (kind == aten::ge) {
return aten::le;
}
throw std::runtime_error(
"reverseComparision: unsupported NodeKind. File a bug");
TORCH_CHECK(false, "reverseComparision: unsupported NodeKind. File a bug");
}
// any expression that can produce a SugaredValue is handled here

View File

@ -94,7 +94,7 @@ C10_EXPORT std::string kindToString(int kind) {
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("Unknown kind: " + std::to_string(kind));
TORCH_CHECK(false, "Unknown kind: ", kind);
}
}

View File

@ -167,12 +167,12 @@ Value* TracingState::getValue(const IValue& var) {
// Didn't find it. Bake in a constant
if (ten.requires_grad()) {
pauseTracing();
std::ostringstream oss;
oss << "Cannot insert a Tensor that requires grad as a constant. "
<< "Consider making it a parameter or input, or detaching the gradient\n"
<< "Tensor:\n"
<< ten;
throw std::runtime_error(oss.str());
TORCH_CHECK(
false,
"Cannot insert a Tensor that requires grad as a constant. ",
"Consider making it a parameter or input, or detaching the gradient\n",
"Tensor:\n",
ten);
}
Value* constant = graph->insertConstant(ten);
@ -208,15 +208,19 @@ Value* TracingState::getValue(const IValue& var) {
}
}
std::ostringstream oss;
if (var.isFuture()) {
oss << "Tried to trace Future or Object that the tracer was not aware of.";
TORCH_CHECK(
false,
"Tried to trace Future or Object that the tracer was not aware of.");
} else {
oss << "Tried to trace " << var
<< " but it is not part of the active trace. Modules that are called during a trace"
<< " must be registered as submodules of the thing being traced.";
TORCH_CHECK(
false,
"Tried to trace ",
var,
" but it is not part of the active trace. Modules that are called during a trace",
" must be registered as submodules of the thing being traced.");
}
throw std::runtime_error(oss.str());
} else {
// If the values are non-tensors, we try to create constants
// and bake those constants into the traced graph
@ -225,11 +229,12 @@ Value* TracingState::getValue(const IValue& var) {
recordSourceLocation(constant.value()->node());
return *constant;
}
std::ostringstream os;
os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
<< "The below value could not be materialized as a constant:\n"
<< var;
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot get value trace for type ",
var.tagKind(),
". The below value could not be materialized as a constant:\n",
var);
}
}
bool TracingState::hasValue(const IValue& var) const {
@ -252,15 +257,14 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
auto& value_map = getTracingState()->env_stack.back();
auto it = value_map.find(iv);
if (it == value_map.end()) {
std::ostringstream os;
os << "output " << i << " (" << var
<< ") of traced region did not have observable "
<< "data dependence with trace inputs; this probably indicates your "
"program "
<< "cannot be understood by the tracer.";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
it != value_map.end(),
"output ",
i,
" (",
var,
") of traced region did not have observable data dependence with trace inputs; ",
"this probably indicates your program cannot be understood by the tracer.");
return it->second;
} else if (iv.isTensorList()) {
if (tracing_mode_strict) {
@ -281,11 +285,10 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
graph->insertNode(tuple_node);
return tuple_node->output();
} else if (iv.isGenericDict()) {
if (tracing_mode_strict) {
throw std::runtime_error(
"Encountering a dict at the output of the tracer" +
std::string(STRICT_TRACER_MSG));
}
TORCH_CHECK(
!tracing_mode_strict,
"Encountering a dict at the output of the tracer",
STRICT_TRACER_MSG);
auto dict = iv.toGenericDict();
TypePtr key_type = dict.keyType();
TypePtr value_type = dict.valueType();
@ -304,15 +307,15 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) {
}
}
}
if (!key_type_valid || !value_type_valid) {
std::ostringstream os;
os << "output " << i << " (" << dict << ") of traced region "
<< "cannot be understood by the tracer, only outputs matching"
<< "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
<< "can be a dictionary output of a traced function";
throw std::runtime_error(os.str());
}
TORCH_CHECK(
key_type_valid && value_type_valid,
"output ",
i,
" (",
dict,
") of traced region cannot be understood by the tracer, only outputs matching ",
"dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] ",
"can be a dictionary output of a traced function");
std::vector<Value*> keys;
std::vector<Value*> values;
for (const auto& entry : dict) {
@ -598,10 +601,11 @@ void TracingState::setValue(const IValue& v, Value* value) {
setValue(entry.value(), static_value);
}
} else {
std::ostringstream os;
os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
<< "Supported types are tensor, tensor list, and tuple of tensors.";
throw std::runtime_error(os.str());
TORCH_CHECK(
false,
"Tracer cannot set value trace for type ",
v.tagKind(),
". Supported types are tensor, tensor list, and tuple of tensors.");
}
}
@ -801,11 +805,10 @@ void addInputs(Node* n, const char* name, at::IntArrayRef value) {
recordSourceLocation(info[i]->node());
}
for (jit::Value* v : info) {
if (*v->type() != *jit::IntType::get()) {
throw std::runtime_error(
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
TORCH_CHECK(
*v->type() == *jit::IntType::get(),
"Type mismatch in setposattr for IntArrayRef. Check that your program "
"is valid without tracing, and please file a bug report if it is.");
}
n->addInput(
g->insertNode(g->createList(jit::IntType::get(), info))->output());

View File

@ -5,6 +5,7 @@
#include <unordered_map>
#include <vector>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/jit/frontend/lexer.h>
@ -37,10 +38,10 @@ struct Tree : c10::intrusive_ptr_target {
return true;
}
virtual const SourceRange& range() const {
throw std::runtime_error("is an Atom");
TORCH_CHECK(false, "is an Atom");
}
virtual const std::string& stringValue() const {
throw std::runtime_error("stringValue can only be called on TK_STRING");
TORCH_CHECK(false, "stringValue can only be called on TK_STRING");
}
virtual const TreeList& trees() const {
static const TreeList empty_trees = {};
@ -79,13 +80,16 @@ struct Tree : c10::intrusive_ptr_target {
int lineno,
size_t expected_subtrees,
bool allow_more) const {
if (kind() != k) {
std::stringstream ss;
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
<< "' but found '" << kindToString(kind()) << "'\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
}
TORCH_CHECK(
kind() == k,
filename,
":",
lineno,
": expecting kind '",
kindToString(k),
"' but found '",
kindToString(kind()),
"'\n");
if (trees().size() < expected_subtrees ||
(!allow_more && trees().size() != expected_subtrees)) {
std::stringstream ss;
@ -93,7 +97,7 @@ struct Tree : c10::intrusive_ptr_target {
<< expected_subtrees << " subtrees, but found only " << trees().size()
<< "\n";
range().highlight(ss);
throw std::runtime_error(ss.str());
TORCH_CHECK(false, ss.str());
}
}
~Tree() override = default;

View File

@ -367,11 +367,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
``treespec``.
"""
if not _is_pytreespec_instance(treespec):
raise TypeError(
f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of "
f"PyTreeSpec but got item of type {type(treespec)}."
)
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]