[BE][14/16] fix typos in torch/ (torch/fx/) (#156604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156604
Approved by: https://github.com/jingsh
ghstack dependencies: #156318, #156320, #156602
Committed by: PyTorch MergeBot
Parent: db259bd6b8
Commit: 11c07c848c
@@ -1173,7 +1173,6 @@ exclude_patterns = [
 'test/distributed/**',
 'torch/**',
 'torch/_*/**',
-'torch/fx/**',
 'torch/distributed/tensor/**',
 'torch/utils/**',
 ]
@@ -36,6 +36,7 @@ serder
 serdes
 statics
 strat
 supercede
 supercedes
 te
 WONT
@@ -52,7 +52,7 @@ demonstration of these components in action:

 The **symbolic tracer** performs "symbolic execution" of the Python
 code. It feeds fake values, called Proxies, through the code. Operations
-on theses Proxies are recorded. More information about symbolic tracing
+on these Proxies are recorded. More information about symbolic tracing
 can be found in the :func:`symbolic_trace` and :class:`Tracer`
 documentation.
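For context (illustration only, not part of this diff), this is roughly what the symbolic tracing described above looks like in use; the module `M` below is a made-up example:

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            # During tracing, x is a Proxy; each operation on it is
            # recorded as a node in the resulting Graph.
            return torch.relu(x) + 1

    traced = torch.fx.symbolic_trace(M())
    print(traced.graph)  # shows the recorded call_function nodes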
@@ -127,7 +127,7 @@ class _LazyGraphModule(GraphModule):

 forward = _lazy_forward

-# TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__,
+# TODO: we should handle __reduce_deploy__ the same way as __reduce_package__,
 # or __reduce__ by calling _real_recompile. But I don't find a good way
 # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule
 # will be used in torch::deploy. So it's skipped for now.
@@ -42,7 +42,7 @@ def tree_flatten_spec(
 # I guess these exist for BC, FC reasons.
 # In general, we should be able to directly
 # use pytree tree flattener to flatten them,
-# as export serializes the pytree seperately.
+# as export serializes the pytree separately.
 # Will remove it in follow up PR.
 if spec.type in SUPPORTED_NODES:
 flatten_fn_spec = SUPPORTED_NODES[spec.type]
@@ -1,5 +1,5 @@
 # Whether to disable showing progress on compilation passes
-# Need to add a new config otherwise wil get a circular import if dynamo config is imported here
+# Need to add a new config otherwise will get a circular import if dynamo config is imported here
 disable_progress = True

 # If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy
@@ -82,7 +82,7 @@ def expand_to_tensor_dim(t, n):
 def broadcast_types(t1, t2):
 """
 Applies broadcasting to both given types such that they
-become consistent with eachother and returns two new
+become consistent with each other and returns two new
 resulting types
 """
@@ -846,7 +846,7 @@ def flatten_refinement_rule(n: Node):
 @register_algebraic_expressions_inference_rule(Conv2d)
 def conv_rule(n: Node, module_instance):
 """
-Represents the outout in terms of an algrbraic expression w.r.t
+Represents the output in terms of an algrbraic expression w.r.t
 the input when possible
 """
 assert isinstance(n.args[0], Node)
@@ -164,7 +164,7 @@ class TGreatestUpperBound(Constraint):

 def __init__(self, res, rhs1, rhs2):
 """
-:param res: tensor variable that stores the result of the outout
+:param res: tensor variable that stores the result of the output
 :param rhs1: tensor or tensor variable
 :param rhs2: tensor or tensor variabke
 """
@@ -407,7 +407,7 @@ class CalcConv(Constraint):
 """
 :param conv_result: the convolution result
 :param input_var: input to convolution
-:param c_out: output chanel type
+:param c_out: output channel type
 :param kernel: kernel tuple
 """
 self.conv_result = conv_result
@@ -823,7 +823,7 @@ def calc_last_two_dims(constraint, d: list[DVar]):
 [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]
 )

-# transform parameters into tuples incase they are not already
+# transform parameters into tuples in case they are not already
 padding = (
 (constraint.padding, constraint.padding)
 if isinstance(constraint.padding, int)
@@ -1415,7 +1415,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
 kwargs = kwargs or {}
 if func in _side_effectful_need_to_be_preserved_pre_dispatch:
 # It's for passing the export verifier which needs to verify the meta['val']
-# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
+# TODO(tmanlaibaatar): we should systematically couple it with export verifier,
 # instead of hardcoding it here.
 # T203648563
 if func == torch.amp.autocast_mode._exit_autocast:
@@ -1432,7 +1432,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
 node.meta["val"] = None
 return node
 # Don't actually run the function! We just want to trace the calls
-# into a graph. We don't actualy want to change global autograd state.
+# into a graph. We don't actually want to change global autograd state.
 return func(*args, **kwargs)
@@ -1333,7 +1333,7 @@ def compute_unbacked_bindings(
 # If old_s is not an unbacked_symbol,
 # we assume that the original unbacked symbol is replaced
 # by a backed symbol (old_s). This can happen
-# when this node reuses the orignal symbol (due to memoi)
+# when this node reuses the original symbol (due to memoi)
 # and the original symbol gets replaced by the backed symbol.
 # When this happens we just replace new_s by the old_s
 # because we know the value is the same.
@@ -2374,7 +2374,7 @@ def _maybe_evaluate_static_worker(

 # Note:
 # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers.
-# Sympy might give unexepected results when comparing an integer with a non-integer
+# Sympy might give unexpected results when comparing an integer with a non-integer
 # Therefore, we cast offset to int here.
 # For example:
 # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True)
@@ -3382,7 +3382,7 @@ class DimConstraints:
 constraint_violation_error: object,
 forced_specializations: dict[str, str],
 ) -> str:
-"""Format a message for constraint violation erros"""
+"""Format a message for constraint violation errors"""
 from torch.export.dynamic_shapes import _get_dim_name_mapping

 if not self._dcp.source_name_to_debug_name:
@@ -3939,7 +3939,7 @@ class ShapeEnv:
 added_replacements[axiom.lhs] = axiom.rhs
 self.axioms.update(new_axioms)

-# We need to freeze the ShapeEnv becuase any additional modification of
+# We need to freeze the ShapeEnv because any additional modification of
 # the ShapeEnv will cause unsoundness for subsequent specialization calls.
 self.frozen = True
 try:
@@ -4473,7 +4473,7 @@ class ShapeEnv:

 # The order of checking the guards matters. In this specific example:
 # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
-# we may have an unnessary shape speciliazation for y.
+# we may have an unnecessary shape speciliazation for y.
 def _maybe_specialize_sym_int_with_hint(
 self, maybe_sym: IntLikeType
 ) -> IntLikeType:
@@ -5249,7 +5249,7 @@ class ShapeEnv:
 # calls on this new instance. Finally, it will check whether this new instance
 # has equal state.
 #
-# It's important that we do it in the begining of this function, since it modifies
+# It's important that we do it in the beginning of this function, since it modifies
 # self.dim_constraints through its execution. Changes that happen in this method
 # aren't interesting, since this is the function call we wish to reproduce at the
 # end. If we wish to simply reproduce ShapeEnv instances even after this call,
@@ -6246,7 +6246,7 @@ class ShapeEnv:

 Use compute_hint == True if you are trying to compute a non-binding
 hint for the particular hint values of backed and unbacked SymInts,
-e.g., if s0 happens to be 3 this run, compute_hint will subsitute s0 with 3.
+e.g., if s0 happens to be 3 this run, compute_hint will substitute s0 with 3.
 """

 # axioms with compute hint NYE
@@ -6267,7 +6267,7 @@ class ShapeEnv:
 # A FloorDiv in implications could have became CleanDiv at this point, due to new facts
 # to the shapeEnv. This handles such issue but its not ideal. This is the only expression
 # simplification that depends on the global state of shape env.
-# TODO try to get rid of CleanDiv since it breaks the invariant thats simplifications of sympy
+# TODO try to get rid of CleanDiv since it breaks the invariant that's simplifications of sympy
 # expressions only depend on the expression itself.
 if k.has(FloorDiv):
 new_items.update({self.simplify(k): v})
@@ -5,9 +5,9 @@ from collections import OrderedDict
 __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]


-def raises(err, lamda):
+def raises(err, lamda): # codespell:ignore lamda
 try:
-lamda()
+lamda() # codespell:ignore lamda
 return False
 except err:
 return True
@@ -23,9 +23,9 @@ def transitive_get(key, d):
 return key


-def raises(err, lamda):
+def raises(err, lamda): # codespell:ignore lamda
 try:
-lamda()
+lamda() # codespell:ignore lamda
 return False
 except err:
 return True
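As a usage note (not part of the commit), the `raises` helper touched in the two hunks above simply reports whether calling the thunk raises the given exception type:

    assert raises(ZeroDivisionError, lambda: 1 / 0)
    assert not raises(ValueError, lambda: int("42"))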
@@ -651,7 +651,7 @@ from torch.fx.experimental import _config as config


 def translation_validation_enabled() -> bool:
-# Checks everytime this function is called, in case the Dynamo
+# Checks every time this function is called, in case the Dynamo
 # option is set, but Z3 is not installed.
 _assert_z3_installed_if_tv_set()
 return _HAS_Z3 and config.translation_validation
@@ -1005,7 +1005,7 @@ class Graph:

 Returns:

-Iteratable of nodes with the requested op and target.
+Iterable of nodes with the requested op and target.
 """
 node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target)
 if sort:
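For context (illustration only), `Graph.find_nodes` is typically used like this; the traced function `f` below is a made-up example:

    import operator
    import torch.fx

    def f(x, y):
        return x + y

    graph = torch.fx.symbolic_trace(f).graph
    # iterable of nodes matching the requested op and target
    add_nodes = graph.find_nodes(op="call_function", target=operator.add)
    print(list(add_nodes))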
@@ -1565,7 +1565,7 @@ class Graph:
 # To do this, we create a new namespace just for this source. All names
 # that get printed must come from this namespace.
 #
-# Why can't we re-use node.name? Because it was generated within the
+# Why can't we reuse node.name? Because it was generated within the
 # namespace `self._graph_namespace`. In order to provide uniqueness
 # over both locals (node.name) *and* globals, we create a completely
 # new namespace to put all identifiers in.
@@ -1573,7 +1573,7 @@ class Graph:

 # Override Node's repr to generate a valid name within our namespace.
 # Since repr() is designed to produce a valid Python expression, it
-# makes sense to re-use it. This way, it's easy to print something like
+# makes sense to reuse it. This way, it's easy to print something like
 # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
 # implemented cooperatively to allow this.
 def node_repr(n: Node):
@@ -995,7 +995,7 @@ class {module_name}(torch.nn.Module):
 @contextlib.contextmanager
 def _set_replace_hook(self, f):
 """
-Takes a callable which will be called everytime when we replace a node
+Takes a callable which will be called every time when we replace a node
 to a new node, or change the node's name. Callable takes three arguments:
 the old node we're changing, and NAME of the new node, followed by the
 user node which consumes the old node to be replaced.
@@ -1009,7 +1009,7 @@ class {module_name}(torch.nn.Module):

 def _register_replace_node_hook(self, f):
 """
-Takes a callable which will be called everytime when we replace a node
+Takes a callable which will be called every time when we replace a node
 to a new node, or change the node's name. Callable takes three arguments:
 the old node we're changing, and NAME of the new node, followed by the
 user node which consumes the old node to be replaced.
@@ -1019,7 +1019,7 @@ class {module_name}(torch.nn.Module):

 def _unregister_replace_node_hook(self, f):
 """
-Takes a callable which was previously registered to be called everytime when we replace a node.
+Takes a callable which was previously registered to be called every time when we replace a node.
 This function will unregister that callable so it is no longer invoked on node replacement.
 """
 assert callable(f), "create_node hook must be a callable."
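To make the hook signature described in these docstrings concrete, a callable passed to these internal, underscore-prefixed methods would look roughly like the following sketch; the hook body is made up:

    def my_replace_hook(old_node, new_node_name, user_node):
        # old node being replaced, NAME of the new node, and the user node
        # that consumes the node being replaced
        print(f"replacing {old_node.name} with {new_node_name} (used by {user_node.name})")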
@@ -245,7 +245,7 @@ class Node(_NodeBase):
 # should not be accessed directly.
 _input_nodes: dict["Node", None]
 # All of the nodes that use the value produced by this Node
-# Note one user may correspond to several uses, e.g. the node fo ``x + x``
+# Note one user may correspond to several uses, e.g. the node for ``x + x``
 # would appear once here, but represents two uses.
 # Is a dict to act as an "ordered set". Keys are significant, value dont-care
 users: dict["Node", None]
@@ -64,8 +64,8 @@ graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose"
 # manage to eliminate all float compute, this ends up being equivalent, but
 # there is a critical difference when some floats cannot be eliminated: when
 # we call item() on them, what should it's SymFloat be? Ideally, it would
-# be the same backed SymFloat we had before. But without symbolic expresssion
-# propogation on tensor quantities, repropagating would instead give you an
+# be the same backed SymFloat we had before. But without symbolic expression
+# propagation on tensor quantities, repropagating would instead give you an
 # unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation
 # on 0d scalar tensors, but I decided to go for something simpler to start.
 #
@@ -346,7 +346,7 @@ def tensorify_python_scalars(
 # Sometimes by the time we get to tensorify, there have already been
 # specializations, eg. in python_arg_parser.h. In these cases,
 # placeholder nodes no longer have a reference to their original
-# symfloat and thus we need to deduce specializations have happend
+# symfloat and thus we need to deduce specializations have happened
 # via shape_env.replacements. NB: there's an important invariant here
 # that symfloats keep consistent names across restarts.
 for k, v in shape_env.var_to_val.items():
@@ -88,7 +88,7 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
 """
 # Total num of elements
 total_num_of_elems = 0
-# For a module, conside all parameters
+# For a module, consider all parameters
 if node.op == "call_module":
 submodule_dict = dict(fx_module.named_modules())
 submodule = submodule_dict[node.target]
@@ -48,7 +48,7 @@ class GraphTransformObserver:
 self.erased_nodes: set[str] = set()
 self.created_nodes: set[str] = set()
 self.name_to_node: dict[str, Node] = {}
-# record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context
+# record graph modules deepcopied from self.gm, so we can remove hooks on them when exiting the context
 self.copied_gms: list[GraphModule] = []

 self._node_creation_hook = self.get_node_creation_hook()
@@ -78,7 +78,7 @@ def _topological_sort_passes(
 if len(constraints) == 0:
 return passes

-# Contruct a graph mapping nodes to a list of their users
+# Construct a graph mapping nodes to a list of their users
 graph: dict[Callable, list[Callable]] = {p: [] for p in passes}
 indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0)
 candidates: Queue = Queue()
@@ -95,7 +95,7 @@ class _MinimizerBase:

 Currently we provides two ways to traverse the graph and generate submodules.
 1. Sequential traversal: this will traverse the graph node by node and generate
-one submodule with one sigle node.
+one submodule with one single node.
 2. Binary searching: this will do a binary search style traversal on the graph.

 For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
@@ -648,7 +648,7 @@ class _MinimizerBase:
 ) -> NodeSet:
 """
 Traverse topologically sorted node list
-Find minimium block (start_idx, end_idx) which contains the culprit
+Find minimum block (start_idx, end_idx) which contains the culprit
 1st pass: search for end_idx by finding the last node in culprit block
 where Numerical accuracy (0, end_idx) > threshold
 2nd pass: search for start_idx by finding the first node in culprit block
@@ -266,7 +266,7 @@ def _get_view_inverse_node_usages(
 continue
 self_alias_base = self_alias.meta["view_of"]
 try:
-# The we're trying to re-use the args from the view_scatter call inside of the corresponding
+# The we're trying to reuse the args from the view_scatter call inside of the corresponding
 # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse
 # of the current alias we're looking at.
 view_replay_metadata = original_view(
@@ -291,7 +291,7 @@ def reinplace(gm, *sample_args):
 mutating the nodes of the graph.
 We look for out-of-place op call sites like `b = a.add(...)`,
 and convert them to be inplace (`b = a.add_(...)`),
-as long as the input to the current operator ("a") isn't re-used
+as long as the input to the current operator ("a") isn't reused
 anywhere later in the graph.

 This pass currently expects to operate on a **functional, ATen** graph.
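To make the reuse condition above concrete, a sketch of the rewrite this pass performs, written as eager-mode pseudocode rather than actual graph nodes:

    # before: out-of-place op; 'a' is not read again afterwards
    b = a.add(1)
    return b

    # after: safe to reinplace, since nothing later uses the original 'a'
    b = a.add_(1)
    return b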
@@ -342,7 +342,7 @@ def reinplace(gm, *sample_args):
 NOTE: there's a future optimization that we should make:
 if "a" is a (alias of a) program input, but later in the program
 there is a node that looks like "a.copy_(...)",
-Then re-inplacing is ok to do - we are temporarily re-using a's buffer,
+Then re-inplacing is ok to do - we are temporarily reusing a's buffer,
 which will later be overwritten by the copy_() call.

 This will be an important optimization to have for programs that mutate
@@ -599,7 +599,7 @@ def reinplace(gm, *sample_args):
 later_node_usages, self_aliases
 )

-# Step 2: Check to see if the input to the op is re-used later in the graph.
+# Step 2: Check to see if the input to the op is reused later in the graph.
 # If not (same goes for its aliases), then this op is safe to re-in place.
 # This is a slightly roundabout way to check that there are no later usages of the current self argument.
 # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete)
@@ -222,7 +222,7 @@ class SplitResult(NamedTuple):
 split_module: root module after splitting.
 submodule_inputs: a dict that maps submodule name to its inputs.
 non_acc_submodule_prefix: the prefix for non acc submodules. For
-acc submodule the prefix is alwasy "_run_on_acc_".
+acc submodule the prefix is always "_run_on_acc_".
 """

 split_module: torch.fx.GraphModule
@@ -44,7 +44,7 @@ def topo_sort(nodes: NodeList) -> NodeList:

 @compatibility(is_backward_compatible=False)
 def validate_partition(partition: NodeList) -> bool:
-# verify the partition does't form a dependency cycle in the original graph
+# verify the partition doesn't form a dependency cycle in the original graph
 # returns True for valid partition, False for invalid

 partition_set = set(partition)
@@ -157,13 +157,13 @@ def fuse_as_graphmodule(

 if x in partition_lookup_table:
 # x is inside subgraph, return the copied node
-# the node should have been copied aleady, as we are copying graph in the topological order
+# the node should have been copied already, as we are copying graph in the topological order
 return node_map[x]

 if x not in node_to_placeholder:
 # x is not in subgraph, create a new placeholder for subgraph
 placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
-# copy all meta fields, even if some fields might be irrelvant for the placeholder node
+# copy all meta fields, even if some fields might be irrelevant for the placeholder node
 placeholder_node.meta = copy.copy(x.meta)
 node_to_placeholder[x] = placeholder_node
@@ -317,7 +317,7 @@ class SubgraphMatcher:
 """
 Returns:
 The matched subgraphs.
-Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder
+The returned subgraph would be fully self-contained, meaning the nodes (except placeholder
 and nodes returned by output) can only be consumed by nodes within the matched subgraph.

 Subgraph pattern matcher is implemented with the backtracking style in the following steps: