mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[2/N] More ruff SIM fixes (#165031)
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
1fa11f42b1
commit
fbe0d20a17
@ -401,7 +401,7 @@ allow_rnn = False
|
||||
# exported FX graph. This flag should become the default eventually
|
||||
# and be removed, but currently provides a way to fall back to old
|
||||
# graph breaking behavior.
|
||||
capture_sparse_compute = False if is_fbcode() else True
|
||||
capture_sparse_compute = not is_fbcode()
|
||||
|
||||
# If true, error if we try to compile a function that has
|
||||
# been seen before.
|
||||
|
||||
@ -718,11 +718,7 @@ def validate_args_and_maybe_create_graph_inputs(
|
||||
new_proxy = tracer.create_graph_input(
|
||||
arg_name, a.python_type(), example_value
|
||||
)
|
||||
example_value = (
|
||||
node.meta["example_value"]
|
||||
if "example_value" in node.meta
|
||||
else None
|
||||
)
|
||||
example_value = node.meta.get("example_value", None)
|
||||
a = wrap_fx_proxy_cls(
|
||||
target_cls=type(a),
|
||||
tx=tx,
|
||||
@ -760,9 +756,7 @@ def validate_args_and_maybe_create_graph_inputs(
|
||||
# If `a` can be put into a graph
|
||||
elif a.maybe_fx_node() is not None:
|
||||
node = a.maybe_fx_node()
|
||||
example_value = (
|
||||
node.meta["example_value"] if "example_value" in node.meta else None
|
||||
)
|
||||
example_value = node.meta.get("example_value", None)
|
||||
arg_name = node.name if sub_args_names is None else sub_args_names[idx]
|
||||
new_proxy = tracer.create_graph_input(
|
||||
arg_name, a.python_type(), example_value
|
||||
|
||||
@ -280,7 +280,7 @@ backward_pass_autocast = "same_as_forward"
|
||||
|
||||
# This controls whether we collect donated buffer. This flag must be set
|
||||
# False if a user wants to retain_graph=True for backward.
|
||||
donated_buffer = False if is_fbcode() else True
|
||||
donated_buffer = not is_fbcode()
|
||||
|
||||
# Controls the default graph output format used by draw_graph
|
||||
# Supported formats are defined here https://graphviz.org/docs/outputs/
|
||||
|
||||
@ -611,8 +611,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
||||
# Use position-based lookup for building output
|
||||
# only update the return node args, and remain all other users unchanged
|
||||
output_updated_args = [
|
||||
position_to_quant[i] if i in position_to_quant else node
|
||||
for i, node in enumerate(fwd_outputs)
|
||||
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
|
||||
]
|
||||
# add the scale nodes to the output find the first sym_node in the output
|
||||
idx = find_first_sym_node(output_updated_args)
|
||||
|
||||
@ -482,15 +482,11 @@ def get_wrapper_codegen_for_device(
|
||||
|
||||
|
||||
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
|
||||
return custom_backend_passes[device] if device in custom_backend_passes else None
|
||||
return custom_backend_passes.get(device)
|
||||
|
||||
|
||||
def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
|
||||
return (
|
||||
custom_backend_codegen_configs[device]
|
||||
if device in custom_backend_codegen_configs
|
||||
else None
|
||||
)
|
||||
return custom_backend_codegen_configs.get(device)
|
||||
|
||||
|
||||
@functools.cache
|
||||
|
||||
@ -1262,7 +1262,7 @@ class triton:
|
||||
cudagraph_trees_history_recording = False
|
||||
|
||||
# Enable cudagraph support for mutated inputs from prior cudagraph pool
|
||||
cudagraph_support_input_mutation = False if is_fbcode() else True
|
||||
cudagraph_support_input_mutation = not is_fbcode()
|
||||
|
||||
# Maximal number of allowed cudagraph re-record for a function and
|
||||
# a cudagraph node due to static input tensor address changes or
|
||||
|
||||
@ -476,9 +476,7 @@ def build_subgraph_buffer(
|
||||
elif node.op == "call_function":
|
||||
# For call_function we use the default lowerings and pass in the
|
||||
# already created TensorBoxes as args
|
||||
args, kwargs = tree_map(
|
||||
lambda x: env[x] if x in env else x, (node.args, node.kwargs)
|
||||
)
|
||||
args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs))
|
||||
env[node] = lowerings[node.target](*args, **kwargs)
|
||||
elif node.op == "output":
|
||||
|
||||
@ -692,9 +690,7 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
|
||||
for node in graph.nodes: # preserve the order of nodes
|
||||
if node in subgraph_node_set:
|
||||
subgraph_node_list.append(node)
|
||||
new_node = new_graph.node_copy(
|
||||
node, lambda x: node_remapping[x] if x in node_remapping else x
|
||||
)
|
||||
new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x))
|
||||
node_remapping[node] = new_node
|
||||
if node is inner_mm:
|
||||
new_input_anchor = new_node
|
||||
|
||||
@ -531,7 +531,7 @@ def _register_quantized_linear_unary_lowering(
|
||||
)
|
||||
|
||||
# bias
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
b = kwargs.get("b")
|
||||
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["output_scale"]
|
||||
@ -593,7 +593,7 @@ def _register_quantized_linear_binary_lowering(
|
||||
kwargs["w_zp"],
|
||||
)
|
||||
# bias
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
b = kwargs.get("b")
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["output_scale"]
|
||||
o_zero_point = kwargs["output_zero_point"]
|
||||
@ -885,10 +885,10 @@ def _register_quantized_maxpool2d_lowering(
|
||||
def qmaxpool2d(match: Match, *args, **kwargs):
|
||||
x = kwargs["x"]
|
||||
kernel_size = kwargs["kernel_size"]
|
||||
stride = kwargs["stride"] if ("stride" in kwargs) else None
|
||||
padding = kwargs["padding"] if ("padding" in kwargs) else 0
|
||||
dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
|
||||
ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
|
||||
stride = kwargs.get("stride")
|
||||
padding = kwargs.get("padding", 0)
|
||||
dilation = kwargs.get("dilation", 1)
|
||||
ceil_mode = kwargs.get("ceil_mode", False)
|
||||
|
||||
if padding == 0:
|
||||
padding = [0, 0]
|
||||
@ -1976,7 +1976,7 @@ def _register_qlinear_weight_prepack_pass(
|
||||
)
|
||||
|
||||
# Params
|
||||
bias = kwargs["b"] if "b" in kwargs else None
|
||||
bias = kwargs.get("b")
|
||||
|
||||
x_shape = qx.meta.get("tensor_meta").shape
|
||||
if has_free_symbols(x_shape):
|
||||
@ -2451,7 +2451,7 @@ def _register_linear_dynamic_fp16_weight_prepack_pass(
|
||||
# find params
|
||||
x = kwargs["x"]
|
||||
w = kwargs["w"]
|
||||
bias = kwargs["b"] if "b" in kwargs else None
|
||||
bias = kwargs.get("b")
|
||||
|
||||
# find linear node
|
||||
nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
|
||||
@ -2727,7 +2727,7 @@ def _register_smooth_quant_int_mm_pattern():
|
||||
pass_number=pass_number,
|
||||
)
|
||||
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
|
||||
bias = kwargs.get("bias", None)
|
||||
bias = kwargs.get("bias")
|
||||
x = kwargs["a"]
|
||||
weight = kwargs["b"]
|
||||
dtype = kwargs["dtype"]
|
||||
@ -2794,7 +2794,7 @@ def _register_smooth_quant_int_mm_pattern():
|
||||
else:
|
||||
# onednn.qlinear does not support per-channel quantization of x
|
||||
# so in this case, we have to apply x scale and add bias ourselves after qlinear
|
||||
in_shape = kwargs.get("in_shape", None)
|
||||
in_shape = kwargs.get("in_shape")
|
||||
if in_shape is None:
|
||||
x_reshaped = x
|
||||
else:
|
||||
@ -2826,8 +2826,8 @@ def _register_smooth_quant_int_mm_pattern():
|
||||
|
||||
# Add bias and reshape
|
||||
has_outer_reshape = (
|
||||
kwargs.get("out_shape_with_bias", None) is not None
|
||||
or kwargs.get("out_shape_no_bias", None) is not None
|
||||
kwargs.get("out_shape_with_bias") is not None
|
||||
or kwargs.get("out_shape_no_bias") is not None
|
||||
)
|
||||
|
||||
if has_outer_reshape:
|
||||
@ -3276,7 +3276,7 @@ def _register_qlinear_post_op_fusion_pass(
|
||||
)
|
||||
|
||||
# bias
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
b = kwargs.get("b")
|
||||
|
||||
# Output QParams
|
||||
o_inv_scale = (
|
||||
|
||||
@ -1074,13 +1074,13 @@ def _overload_method(func):
|
||||
_check_overload_body(func)
|
||||
qual_name = _qualified_name(func)
|
||||
global _overloaded_methods
|
||||
class_name_map = _overloaded_methods.get(qual_name, None)
|
||||
class_name_map = _overloaded_methods.get(qual_name)
|
||||
if class_name_map is None:
|
||||
class_name_map = {}
|
||||
_overloaded_methods[qual_name] = class_name_map
|
||||
|
||||
class_name, line_no = get_class_name_lineno(func)
|
||||
method_overloads = class_name_map.get(class_name, None)
|
||||
method_overloads = class_name_map.get(class_name)
|
||||
if method_overloads is None:
|
||||
method_overloads = []
|
||||
class_name_map[class_name] = method_overloads
|
||||
@ -1102,7 +1102,7 @@ def _get_overloaded_methods(method, mod_class):
|
||||
if not hasattr(method, "__name__"):
|
||||
return None
|
||||
qual_name = _qualified_name(method)
|
||||
class_name_map = _overloaded_methods.get(qual_name, None)
|
||||
class_name_map = _overloaded_methods.get(qual_name)
|
||||
if class_name_map is None:
|
||||
return None
|
||||
overloads = class_name_map.get(mod_class.__name__, None)
|
||||
|
||||
@ -5307,7 +5307,7 @@ def grid_sampler_3d_backward(
|
||||
|
||||
@register_meta([aten.full.default])
|
||||
def full(size, fill_value, *args, **kwargs):
|
||||
dtype = kwargs.get("dtype", None)
|
||||
dtype = kwargs.get("dtype")
|
||||
if not dtype:
|
||||
dtype = utils.get_dtype(fill_value)
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
@ -1409,7 +1409,7 @@ class _HigherOrderNamespace(types.ModuleType):
|
||||
|
||||
def __getattr__(self, name: str) -> HigherOrderOperator:
|
||||
# Following _OpNamespace.__getattr__, we cache the op on this object.
|
||||
op = _higher_order_ops.get(name, None)
|
||||
op = _higher_order_ops.get(name)
|
||||
if op is None:
|
||||
raise AttributeError(
|
||||
f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"
|
||||
|
||||
@ -87,13 +87,9 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||
# for capturing `.item` operations
|
||||
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
|
||||
"quant_min", None
|
||||
)
|
||||
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
|
||||
"quant_max", None
|
||||
)
|
||||
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
|
||||
|
||||
def get_weight(self):
|
||||
"""
|
||||
|
||||
@ -240,29 +240,29 @@ scale_min_lower_bound=None, scale_max_upper_bound=None)
|
||||
"bias_type": torch.dtype
|
||||
"is_dynamic": bool
|
||||
"""
|
||||
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
|
||||
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY)
|
||||
if input_dtype is not None and not isinstance(
|
||||
input_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
|
||||
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY)
|
||||
if output_dtype is not None and not isinstance(
|
||||
output_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
|
||||
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY)
|
||||
if weight_dtype is not None and not isinstance(
|
||||
weight_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
|
||||
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
|
||||
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY)
|
||||
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY)
|
||||
return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
@ -673,23 +673,23 @@ class BackendPatternConfig:
|
||||
for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
|
||||
conf.add_dtype_config(_get_dtype_config(d))
|
||||
conf.set_root_module(
|
||||
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type]
|
||||
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY)) # type: ignore[arg-type]
|
||||
conf.set_reference_quantized_module(
|
||||
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_fused_module(
|
||||
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_fuser_method(
|
||||
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_root_node_getter(
|
||||
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_extra_inputs_getter(
|
||||
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_num_tensor_args_to_observation_type(
|
||||
backend_pattern_config_dict.get(
|
||||
|
||||
@ -286,7 +286,7 @@ def get_fuser_method_new(
|
||||
op_patterns = _get_valid_patterns(op_pattern)
|
||||
fuser_method = None
|
||||
for op_pattern in op_patterns:
|
||||
fuser_method = fuser_method_mapping.get(op_pattern, None)
|
||||
fuser_method = fuser_method_mapping.get(op_pattern)
|
||||
if fuser_method is not None:
|
||||
break
|
||||
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
|
||||
|
||||
@ -168,7 +168,7 @@ def _find_matches(
|
||||
for node in reversed(graph.nodes):
|
||||
if node.name not in match_map and node.name not in all_matched:
|
||||
for pattern, quantize_handler_cls in patterns.items():
|
||||
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
||||
root_node_getter = root_node_getter_mapping.get(pattern)
|
||||
if _is_match(modules, node, pattern) and node.name not in match_map:
|
||||
matched_node_pattern: list[Node] = []
|
||||
record_match(pattern, node, node, matched_node_pattern, match_map)
|
||||
|
||||
@ -130,7 +130,7 @@ def _get_qspec_for_arg(
|
||||
) -> Optional[QuantizationSpecBase]:
|
||||
while _is_activation_post_process_node(arg, named_modules):
|
||||
arg = arg.args[0] # type: ignore[assignment]
|
||||
return input_qspec_map.get(arg, None)
|
||||
return input_qspec_map.get(arg)
|
||||
|
||||
|
||||
def _create_obs_or_fq_from_qspec(
|
||||
|
||||
@ -164,7 +164,7 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
|
||||
torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
|
||||
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
|
||||
}
|
||||
prepack_op = prepack_ops.get(conv_op, None)
|
||||
prepack_op = prepack_ops.get(conv_op)
|
||||
assert prepack_op, f"Didn't find prepack op for {conv_op}"
|
||||
return prepack_op
|
||||
|
||||
|
||||
@ -806,7 +806,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
):
|
||||
version = local_metadata.get("version", None)
|
||||
version = local_metadata.get("version")
|
||||
if version is not None and version < 3:
|
||||
local_state = ["min_vals", "max_vals"]
|
||||
expected_min_name = "min_vals"
|
||||
|
||||
@ -366,7 +366,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
||||
if input_edge_obs_or_fq is None:
|
||||
return new_arg
|
||||
|
||||
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
|
||||
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg)
|
||||
# the arg is observed as the output and is using the same instance as the input_edge
|
||||
# we'll reuse the inserted observer/fake_quant
|
||||
if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
|
||||
@ -497,11 +497,7 @@ def _maybe_insert_input_and_output_observers_for_node(
|
||||
is_qat: bool,
|
||||
model_device: Optional[torch.device] = None,
|
||||
):
|
||||
this_node_quantization_annotation = (
|
||||
node.meta["quantization_annotation"]
|
||||
if "quantization_annotation" in node.meta
|
||||
else None
|
||||
)
|
||||
this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
|
||||
if this_node_quantization_annotation is None:
|
||||
return
|
||||
|
||||
|
||||
@ -343,7 +343,7 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
|
||||
# TODO: merge with get_static_quant_module_class
|
||||
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
|
||||
"""Get the quantized operator corresponding to the float operator"""
|
||||
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
|
||||
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
|
||||
assert quantized_op is not None, (
|
||||
f"Operator {str(float_op)} does not have corresponding quantized op"
|
||||
)
|
||||
|
||||
@ -1357,11 +1357,7 @@ class X86InductorQuantizer(Quantizer):
|
||||
def _annotate_output_share_observer_as_input(
|
||||
self, input_node: Node, source_node: Node
|
||||
):
|
||||
source_node_quantization_annotation = (
|
||||
source_node.meta[QUANT_ANNOTATION_KEY]
|
||||
if QUANT_ANNOTATION_KEY in source_node.meta
|
||||
else None
|
||||
)
|
||||
source_node_quantization_annotation = source_node.meta.get(QUANT_ANNOTATION_KEY)
|
||||
if (
|
||||
source_node_quantization_annotation
|
||||
and source_node_quantization_annotation._is_output_of_quantized_pattern
|
||||
@ -1400,10 +1396,8 @@ class X86InductorQuantizer(Quantizer):
|
||||
return
|
||||
|
||||
# Get the quantization_annotation from getitem_node
|
||||
maxpool_node_quantization_annotation = (
|
||||
maxpool_node.meta[QUANT_ANNOTATION_KEY]
|
||||
if QUANT_ANNOTATION_KEY in maxpool_node.meta
|
||||
else None
|
||||
maxpool_node_quantization_annotation = maxpool_node.meta.get(
|
||||
QUANT_ANNOTATION_KEY
|
||||
)
|
||||
if (
|
||||
maxpool_node_quantization_annotation
|
||||
|
||||
@ -159,10 +159,7 @@ class EventList(list):
|
||||
if p is not None:
|
||||
assert p.fwd_thread is not None
|
||||
t = (p.sequence_nr, p.fwd_thread)
|
||||
if t in fwd_stacks:
|
||||
evt.stack = fwd_stacks[t]
|
||||
else:
|
||||
evt.stack = []
|
||||
evt.stack = fwd_stacks.get(t, [])
|
||||
|
||||
@property
|
||||
def self_cpu_time_total(self):
|
||||
|
||||
@ -214,7 +214,7 @@ def replicate(
|
||||
|
||||
state = replicate.state(module)
|
||||
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
|
||||
device_mesh = kwargs.get("device_mesh", None)
|
||||
device_mesh = kwargs.get("device_mesh")
|
||||
if device_mesh is not None:
|
||||
from torch.distributed.device_mesh import _mesh_resources
|
||||
|
||||
|
||||
@ -228,7 +228,7 @@ def replicate_impl(
|
||||
# Place Replicate leftmost for highest priority in the method resolution order
|
||||
for module in modules:
|
||||
cls = module.__class__
|
||||
new_cls = cls_to_replicate_cls.get(cls, None)
|
||||
new_cls = cls_to_replicate_cls.get(cls)
|
||||
if not new_cls:
|
||||
dct = {"__deepcopy__": _unimplemented_deepcopy}
|
||||
new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct)
|
||||
|
||||
@ -143,7 +143,7 @@ def register_tensor_creation_op(op):
|
||||
takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
|
||||
``torch.full_like``.
|
||||
"""
|
||||
creation_op = tensor_like_creation_op_map.get(op, None)
|
||||
creation_op = tensor_like_creation_op_map.get(op)
|
||||
if creation_op is None:
|
||||
raise RuntimeError(f"Tensor creation {op} not supported!")
|
||||
if kwargs is None:
|
||||
|
||||
@ -678,7 +678,7 @@ class ShardedTensor(ShardedTensorBase):
|
||||
copy_tensor = kwargs.get("copy", False)
|
||||
non_blocking = kwargs.get("non_blocking", False)
|
||||
memory_format = kwargs.get("memory_format", torch.preserve_format)
|
||||
process_group = kwargs.get("process_group", None)
|
||||
process_group = kwargs.get("process_group")
|
||||
|
||||
if (
|
||||
not copy_tensor
|
||||
|
||||
@ -605,7 +605,7 @@ def _distribute_tensors(
|
||||
if pg is None:
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
for key in keys:
|
||||
_local_state = local_state_dict.get(key, None)
|
||||
_local_state = local_state_dict.get(key)
|
||||
if _local_state is None or torch.is_tensor(_local_state):
|
||||
continue
|
||||
|
||||
|
||||
@ -127,7 +127,7 @@ def aggregate_stats(
|
||||
}
|
||||
|
||||
for mod in model.modules():
|
||||
if mod_mem_stat := mod_mem_stats.get(mod, None):
|
||||
if mod_mem_stat := mod_mem_stats.get(mod):
|
||||
if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
|
||||
sac_runtime = tradeoff_stats.sac_runtime
|
||||
sac_memory = tradeoff_stats.sac_memory
|
||||
|
||||
@ -711,7 +711,7 @@ class SACEstimator(TorchDispatchMode):
|
||||
str(i in sac_stats.view_like_ops),
|
||||
str(i in sac_stats.rand_ops),
|
||||
str(i in sac_stats.saved_autograd_ops),
|
||||
str(op_parent.get(i, None)),
|
||||
str(op_parent.get(i)),
|
||||
]
|
||||
table_data.append(row)
|
||||
# Define headers
|
||||
|
||||
@ -107,7 +107,7 @@ def auto_quantize(func, qtype, quant_loss=None):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
group = kwargs.get("group", None)
|
||||
group = kwargs.get("group")
|
||||
async_op = kwargs.get("async_op", False)
|
||||
if async_op is True:
|
||||
raise RuntimeError("The async_op=True mode is not supported yet.")
|
||||
@ -133,8 +133,8 @@ def auto_quantize(func, qtype, quant_loss=None):
|
||||
|
||||
elif func == dist.all_to_all_single:
|
||||
tensors = args[0]
|
||||
out_splits = kwargs.get("out_splits", None)
|
||||
in_splits = kwargs.get("in_splits", None)
|
||||
out_splits = kwargs.get("out_splits")
|
||||
in_splits = kwargs.get("in_splits")
|
||||
# Quantizing the input/output tensor
|
||||
input_tensors = _quantize_tensor(args[1], qtype)
|
||||
out_tensors = _quantize_tensor(tensors, qtype)
|
||||
|
||||
@ -631,7 +631,7 @@ class _FileSystemWriter(StorageWriter):
|
||||
def set_up_storage_writer(
|
||||
self, is_coordinator: bool, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
self.rank = kwargs.get("rank", None)
|
||||
self.rank = kwargs.get("rank")
|
||||
self.use_collectives = kwargs.get("use_collectives", True)
|
||||
|
||||
def _metadata_exists(self) -> bool:
|
||||
@ -919,7 +919,7 @@ class FileSystemReader(StorageReader):
|
||||
|
||||
# Implementing the abstract function in StorageReader
|
||||
def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
|
||||
rank = kwargs.get("rank", None)
|
||||
rank = kwargs.get("rank")
|
||||
path = self._get_metadata_path(rank)
|
||||
with self.fs.create_stream(path, "rb") as metadata_file:
|
||||
metadata = pickle.load(metadata_file)
|
||||
@ -934,7 +934,7 @@ class FileSystemReader(StorageReader):
|
||||
self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
self.storage_data = metadata.storage_data
|
||||
self.rank = kwargs.get("rank", None)
|
||||
self.rank = kwargs.get("rank")
|
||||
self.use_collectives = kwargs.get("use_collectives", True)
|
||||
assert self.storage_data is not None
|
||||
|
||||
|
||||
@ -31,11 +31,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
|
||||
msg_dict = {}
|
||||
|
||||
# checkpoint ID can be passed in through the serializer or through the checkpoint id directly
|
||||
storage_writer = kwargs.get("storage_writer", None)
|
||||
storage_reader = kwargs.get("storage_reader", None)
|
||||
planner = kwargs.get("planner", None)
|
||||
storage_writer = kwargs.get("storage_writer")
|
||||
storage_reader = kwargs.get("storage_reader")
|
||||
planner = kwargs.get("planner")
|
||||
|
||||
checkpoint_id = kwargs.get("checkpoint_id", None)
|
||||
checkpoint_id = kwargs.get("checkpoint_id")
|
||||
if not checkpoint_id and (serializer := storage_writer or storage_reader):
|
||||
# pyrefly: ignore # unbound-name
|
||||
checkpoint_id = getattr(serializer, "checkpoint_id", None)
|
||||
|
||||
@ -307,7 +307,7 @@ def _verify_options(
|
||||
continue
|
||||
|
||||
fqns = _get_fqns(model, name)
|
||||
fqn = fqn_param_mapping.get(param, None)
|
||||
fqn = fqn_param_mapping.get(param)
|
||||
if fqn is not None:
|
||||
cast(set[str], fqn_param_mapping[param]).update(fqns)
|
||||
shared_params_mapping[param] = fqn_param_mapping[param]
|
||||
|
||||
@ -5081,7 +5081,7 @@ def _is_safe_to_split() -> bool:
|
||||
users must be aware that a pg is only splittable after the first collective is
|
||||
issued.
|
||||
"""
|
||||
return False if _get_default_group().bound_device_id is None else True
|
||||
return _get_default_group().bound_device_id is not None
|
||||
|
||||
|
||||
@_time_logger
|
||||
|
||||
@ -88,10 +88,7 @@ def configure(handler: MetricHandler, group: Optional[str] = None):
|
||||
|
||||
|
||||
def getStream(group: str):
|
||||
if group in _metrics_map:
|
||||
handler = _metrics_map[group]
|
||||
else:
|
||||
handler = _default_metrics_handler
|
||||
handler = _metrics_map.get(group, _default_metrics_handler)
|
||||
return MetricStream(group, handler)
|
||||
|
||||
|
||||
|
||||
@ -241,7 +241,7 @@ def fully_shard(
|
||||
# Place FSDP leftmost for highest priority in the method resolution order
|
||||
for module in modules:
|
||||
cls = module.__class__
|
||||
new_cls = cls_to_fsdp_cls.get(cls, None)
|
||||
new_cls = cls_to_fsdp_cls.get(cls)
|
||||
if not new_cls:
|
||||
dct = {"__deepcopy__": _unimplemented_deepcopy}
|
||||
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
|
||||
|
||||
@ -1270,7 +1270,7 @@ def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
|
||||
(which usually are FQNs) versus integers (which usually refer to param_ids
|
||||
from a vanilla torch.optim.Optimizer).
|
||||
"""
|
||||
state = optim_state_dict.get("state", None)
|
||||
state = optim_state_dict.get("state")
|
||||
if not state:
|
||||
# If we cannot find a state, assume it is not NamedOptimizer as
|
||||
# NamedOptimizer has eager initialization.
|
||||
@ -1718,7 +1718,7 @@ def _convert_state_with_orig_params(
|
||||
# across ranks
|
||||
for optim_state_key in all_optim_state_keys:
|
||||
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
|
||||
optim_state_key, None
|
||||
optim_state_key
|
||||
)
|
||||
|
||||
if param_key is None and not optim_state_key.is_fsdp_managed:
|
||||
@ -1726,7 +1726,7 @@ def _convert_state_with_orig_params(
|
||||
|
||||
if optim_state_key.is_fsdp_managed:
|
||||
fqn = optim_state_key.unflat_param_names[0]
|
||||
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
|
||||
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn)
|
||||
if fsdp_param_info is None:
|
||||
# This can happen if the not all FSDP instances have all the
|
||||
# parameters. This can happen with FSDP + some MPMD style
|
||||
@ -1804,7 +1804,7 @@ def _convert_state_with_flat_params(
|
||||
# across ranks
|
||||
for optim_state_key in all_optim_state_keys:
|
||||
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
|
||||
optim_state_key, None
|
||||
optim_state_key
|
||||
)
|
||||
|
||||
assert param_key is not None, (
|
||||
|
||||
@ -52,7 +52,7 @@ class _ScriptLocalOptimizer(nn.Module):
|
||||
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
|
||||
# apply functional optimizer step with a list of gradients
|
||||
grads: list[Optional[Tensor]] = [
|
||||
all_local_grads[p] if p in all_local_grads else None
|
||||
all_local_grads[p] if p in all_local_grads else None # noqa: SIM401
|
||||
for p in self._local_params
|
||||
]
|
||||
|
||||
|
||||
@ -189,7 +189,7 @@ def _insert_stage_symbolic_backward(
|
||||
output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
|
||||
if node in tuples:
|
||||
stage_output = tuples[node]
|
||||
output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
|
||||
output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
|
||||
outputs_with_grads_idxs = [
|
||||
i for i, n in enumerate(tuples[node]) if n in live_nodes
|
||||
]
|
||||
|
||||
@ -114,7 +114,7 @@ def get_param_groups(
|
||||
"intermediates": intersected,
|
||||
}
|
||||
for input_node in intersected:
|
||||
existing = param_groups.get(input_node, None)
|
||||
existing = param_groups.get(input_node)
|
||||
if existing is not None:
|
||||
existing["params"] = existing["params"].union(param_group["params"])
|
||||
existing["intermediates"] = existing["intermediates"].union(
|
||||
|
||||
@ -326,8 +326,7 @@ def _insert_copy_for_mutations(
|
||||
return_nodes_to_copy[return_node] = copy_node
|
||||
|
||||
output_args = tuple(
|
||||
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
|
||||
for node in user_output_nodes
|
||||
return_nodes_to_copy.get(node, node) for node in user_output_nodes
|
||||
)
|
||||
with gm.graph.inserting_before(output_node):
|
||||
# Only return user outputs
|
||||
|
||||
@ -46,7 +46,7 @@ class NormalizeArgs(Transformer):
|
||||
|
||||
def get_type(arg):
|
||||
if isinstance(arg, fx.Node):
|
||||
return n.meta["type"] if "type" in n.meta else None
|
||||
return n.meta.get("type")
|
||||
return type(arg)
|
||||
|
||||
arg_types = map_aggregate(n.args, get_type)
|
||||
|
||||
@ -4414,7 +4414,7 @@ class ShapeEnv:
|
||||
size = []
|
||||
for i, val in enumerate(tensor_size):
|
||||
sym = self.create_symbol(
|
||||
val if i not in hint_overrides else hint_overrides[i],
|
||||
hint_overrides.get(i, val),
|
||||
TensorPropertySource(source, TensorProperty.SIZE, i),
|
||||
dynamic_dims[i],
|
||||
constraint_dims[i],
|
||||
@ -4615,7 +4615,7 @@ class ShapeEnv:
|
||||
sym_sizes = [
|
||||
self.create_symintnode(
|
||||
sym,
|
||||
hint=hint if i not in hint_overrides else hint_overrides[i],
|
||||
hint=hint_overrides.get(i, hint),
|
||||
source=TensorPropertySource(source, TensorProperty.SIZE, i),
|
||||
)
|
||||
for i, (sym, hint) in enumerate(zip(size, ex_size))
|
||||
|
||||
@ -930,7 +930,7 @@ class ParameterDict(Module):
|
||||
key (str): key to get from the ParameterDict
|
||||
default (Parameter, optional): value to return if key not present
|
||||
"""
|
||||
return self[key] if key in self else default
|
||||
return self[key] if key in self else default # noqa: SIM401
|
||||
|
||||
def fromkeys(
|
||||
self, keys: Iterable[str], default: Optional[Any] = None
|
||||
|
||||
@ -1761,11 +1761,7 @@ class Module:
|
||||
if recording_scopes:
|
||||
# type ignore was added because at this point one knows that
|
||||
# torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
|
||||
name = (
|
||||
torch.jit._trace._trace_module_map[self] # type: ignore[index]
|
||||
if self in torch.jit._trace._trace_module_map # type: ignore[operator]
|
||||
else None
|
||||
) # noqa: B950
|
||||
name = torch.jit._trace._trace_module_map.get(self, None) # type: ignore[operator, union-attr]
|
||||
if name:
|
||||
tracing_state.push_scope(name)
|
||||
else:
|
||||
|
||||
@ -218,7 +218,7 @@ def _dump_DDP_relevant_env_vars():
|
||||
]
|
||||
formatted_output = ""
|
||||
for var in relevant_env_vars:
|
||||
value = os.environ[var] if var in os.environ else "N/A"
|
||||
value = os.environ.get(var, "N/A")
|
||||
formatted_output += f"env:{var}={value}\n"
|
||||
print(formatted_output)
|
||||
|
||||
|
||||
@ -783,8 +783,8 @@ class Optimizer:
|
||||
assert param_groups is not None
|
||||
for pg in param_groups:
|
||||
if param_id in pg["params"]:
|
||||
fused = pg["fused"] if "fused" in pg else False
|
||||
capturable = pg["capturable"] if "capturable" in pg else False
|
||||
fused = pg.get("fused", False)
|
||||
capturable = pg.get("capturable", False)
|
||||
break
|
||||
if key == "step":
|
||||
if capturable or fused:
|
||||
|
||||
@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
|
||||
return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
|
||||
|
||||
def _apply_precision_override_for_test(self, test, param_kwargs):
|
||||
dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
|
||||
dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
|
||||
dtype = param_kwargs.get("dtype")
|
||||
dtype = param_kwargs.get("dtypes", dtype)
|
||||
if dtype:
|
||||
self.precision = self._get_precision_override(test, dtype)
|
||||
self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)
|
||||
|
||||
@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
|
||||
for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
|
||||
# The scalar we are passing to new_full must be the same dtype
|
||||
# as the one of the resulting tensor
|
||||
use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
|
||||
use_dtype = sample.kwargs.get('dtype', dtype)
|
||||
yield SampleInput(
|
||||
sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
|
||||
|
||||
|
||||
@ -725,7 +725,7 @@ class DistributedTest:
|
||||
lines = out.getvalue().splitlines()
|
||||
|
||||
def format_line(var):
|
||||
return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
|
||||
return f"env:{var}={os.environ.get(var, 'N/A')}"
|
||||
|
||||
# Check relevant env vars
|
||||
vars = [
|
||||
@ -6212,7 +6212,7 @@ class DistributedTest:
|
||||
)
|
||||
def test_ddp_logging_data_cpu(self):
|
||||
def parse_env(var):
|
||||
return os.environ[var] if var in os.environ else "N/A"
|
||||
return os.environ.get(var, "N/A")
|
||||
|
||||
dist.set_debug_level(dist.DebugLevel.INFO)
|
||||
_, group_id, _ = self._init_global_test()
|
||||
|
||||
@ -21,7 +21,7 @@ INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
|
||||
|
||||
|
||||
def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
|
||||
return _MIRROR_REL_OP.get(type, None)
|
||||
return _MIRROR_REL_OP.get(type)
|
||||
|
||||
|
||||
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
|
||||
|
||||
@ -277,7 +277,7 @@ def create_graph(objects, *, context=None, filter=None):
|
||||
references = annotated_references(obj)
|
||||
for referrent in gc.get_referents(obj):
|
||||
rid = id(referrent)
|
||||
tidx = id_to_node.get(rid, None)
|
||||
tidx = id_to_node.get(rid)
|
||||
if tidx is None:
|
||||
continue
|
||||
labels = references.get(rid, ["?"])
|
||||
|
||||
Reference in New Issue
Block a user