[2/N] More ruff SIM fixes (#165031)

This is a follow-up to #164695, applying ruff SIM rules to more files. Most of the changes simplify dict.get(key, None) to dict.get(key), since None is already the default.
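
The rewrite patterns involved look roughly like this (a minimal sketch with hypothetical call sites; is_fbcode and meta are illustrative stand-ins, and the SIM rule numbers are ruff's usual ones for these patterns -- only SIM401 is confirmed by the noqa comments in the diff below):

    # Hypothetical call sites illustrating the simplifications applied in this PR.
    def is_fbcode() -> bool:  # stand-in for the real config helper
        return False

    meta = {"example_value": 42}

    # dict.get with an explicit None default (ruff SIM910): None is already the default.
    v1 = meta.get("example_value", None)  # before
    v1 = meta.get("example_value")        # after

    # Membership-test ternary (ruff SIM401): equivalent to dict.get with a default.
    v2 = meta["missing"] if "missing" in meta else 0  # before
    v2 = meta.get("missing", 0)                       # after

    # "False if cond else True" (ruff SIM211): just negate the condition.
    flag = False if is_fbcode() else True  # before
    flag = not is_fbcode()                 # after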

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031
Approved by: https://github.com/mlazos
Yuanyuan Chen
2025-10-14 14:22:49 +00:00
committed by PyTorch MergeBot
parent 1fa11f42b1
commit fbe0d20a17
52 changed files with 98 additions and 138 deletions

View File

@@ -401,7 +401,7 @@ allow_rnn = False
 # exported FX graph. This flag should become the default eventually
 # and be removed, but currently provides a way to fall back to old
 # graph breaking behavior.
-capture_sparse_compute = False if is_fbcode() else True
+capture_sparse_compute = not is_fbcode()
 # If true, error if we try to compile a function that has
 # been seen before.

View File

@@ -718,11 +718,7 @@ def validate_args_and_maybe_create_graph_inputs(
                 new_proxy = tracer.create_graph_input(
                     arg_name, a.python_type(), example_value
                 )
-                example_value = (
-                    node.meta["example_value"]
-                    if "example_value" in node.meta
-                    else None
-                )
+                example_value = node.meta.get("example_value", None)
                 a = wrap_fx_proxy_cls(
                     target_cls=type(a),
                     tx=tx,
@@ -760,9 +756,7 @@ def validate_args_and_maybe_create_graph_inputs(
         # If `a` can be put into a graph
         elif a.maybe_fx_node() is not None:
             node = a.maybe_fx_node()
-            example_value = (
-                node.meta["example_value"] if "example_value" in node.meta else None
-            )
+            example_value = node.meta.get("example_value", None)
             arg_name = node.name if sub_args_names is None else sub_args_names[idx]
             new_proxy = tracer.create_graph_input(
                 arg_name, a.python_type(), example_value

View File

@@ -280,7 +280,7 @@ backward_pass_autocast = "same_as_forward"
 # This controls whether we collect donated buffer. This flag must be set
 # False if a user wants to retain_graph=True for backward.
-donated_buffer = False if is_fbcode() else True
+donated_buffer = not is_fbcode()
 # Controls the default graph output format used by draw_graph
 # Supported formats are defined here https://graphviz.org/docs/outputs/

View File

@@ -611,8 +611,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
     # Use position-based lookup for building output
     # only update the return node args, and remain all other users unchanged
     output_updated_args = [
-        position_to_quant[i] if i in position_to_quant else node
-        for i, node in enumerate(fwd_outputs)
+        position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
     ]
     # add the scale nodes to the output find the first sym_node in the output
     idx = find_first_sym_node(output_updated_args)

View File

@@ -482,15 +482,11 @@ def get_wrapper_codegen_for_device(
 def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
-    return custom_backend_passes[device] if device in custom_backend_passes else None
+    return custom_backend_passes.get(device)
 def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
-    return (
-        custom_backend_codegen_configs[device]
-        if device in custom_backend_codegen_configs
-        else None
-    )
+    return custom_backend_codegen_configs.get(device)
 @functools.cache

View File

@@ -1262,7 +1262,7 @@ class triton:
     cudagraph_trees_history_recording = False
     # Enable cudagraph support for mutated inputs from prior cudagraph pool
-    cudagraph_support_input_mutation = False if is_fbcode() else True
+    cudagraph_support_input_mutation = not is_fbcode()
     # Maximal number of allowed cudagraph re-record for a function and
     # a cudagraph node due to static input tensor address changes or

View File

@@ -476,9 +476,7 @@ def build_subgraph_buffer(
         elif node.op == "call_function":
             # For call_function we use the default lowerings and pass in the
             # already created TensorBoxes as args
-            args, kwargs = tree_map(
-                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
-            )
+            args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs))
             env[node] = lowerings[node.target](*args, **kwargs)
         elif node.op == "output":
@@ -692,9 +690,7 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
     for node in graph.nodes:  # preserve the order of nodes
         if node in subgraph_node_set:
             subgraph_node_list.append(node)
-            new_node = new_graph.node_copy(
-                node, lambda x: node_remapping[x] if x in node_remapping else x
-            )
+            new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x))
             node_remapping[node] = new_node
             if node is inner_mm:
                 new_input_anchor = new_node

View File

@@ -531,7 +531,7 @@ def _register_quantized_linear_unary_lowering(
         )
         # bias
-        b = kwargs["b"] if "b" in kwargs else None
+        b = kwargs.get("b")
         # Output QParams
         o_inv_scale = kwargs["output_scale"]
@@ -593,7 +593,7 @@ def _register_quantized_linear_binary_lowering(
             kwargs["w_zp"],
         )
         # bias
-        b = kwargs["b"] if "b" in kwargs else None
+        b = kwargs.get("b")
         # Output QParams
         o_inv_scale = kwargs["output_scale"]
         o_zero_point = kwargs["output_zero_point"]
@@ -885,10 +885,10 @@ def _register_quantized_maxpool2d_lowering(
     def qmaxpool2d(match: Match, *args, **kwargs):
         x = kwargs["x"]
         kernel_size = kwargs["kernel_size"]
-        stride = kwargs["stride"] if ("stride" in kwargs) else None
-        padding = kwargs["padding"] if ("padding" in kwargs) else 0
-        dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
-        ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
+        stride = kwargs.get("stride")
+        padding = kwargs.get("padding", 0)
+        dilation = kwargs.get("dilation", 1)
+        ceil_mode = kwargs.get("ceil_mode", False)
         if padding == 0:
             padding = [0, 0]
@@ -1976,7 +1976,7 @@ def _register_qlinear_weight_prepack_pass(
         )
         # Params
-        bias = kwargs["b"] if "b" in kwargs else None
+        bias = kwargs.get("b")
         x_shape = qx.meta.get("tensor_meta").shape
         if has_free_symbols(x_shape):
@@ -2451,7 +2451,7 @@ def _register_linear_dynamic_fp16_weight_prepack_pass(
         # find params
         x = kwargs["x"]
         w = kwargs["w"]
-        bias = kwargs["b"] if "b" in kwargs else None
+        bias = kwargs.get("b")
         # find linear node
         nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
@@ -2727,7 +2727,7 @@ def _register_smooth_quant_int_mm_pattern():
         pass_number=pass_number,
     )
     def _int_mm_weight_prepack(match: Match, *args, **kwargs):
-        bias = kwargs.get("bias", None)
+        bias = kwargs.get("bias")
         x = kwargs["a"]
         weight = kwargs["b"]
         dtype = kwargs["dtype"]
@@ -2794,7 +2794,7 @@ def _register_smooth_quant_int_mm_pattern():
         else:
             # onednn.qlinear does not support per-channel quantization of x
             # so in this case, we have to apply x scale and add bias ourselves after qlinear
-            in_shape = kwargs.get("in_shape", None)
+            in_shape = kwargs.get("in_shape")
             if in_shape is None:
                 x_reshaped = x
             else:
@@ -2826,8 +2826,8 @@ def _register_smooth_quant_int_mm_pattern():
             # Add bias and reshape
             has_outer_reshape = (
-                kwargs.get("out_shape_with_bias", None) is not None
-                or kwargs.get("out_shape_no_bias", None) is not None
+                kwargs.get("out_shape_with_bias") is not None
+                or kwargs.get("out_shape_no_bias") is not None
             )
             if has_outer_reshape:
@@ -3276,7 +3276,7 @@ def _register_qlinear_post_op_fusion_pass(
         )
         # bias
-        b = kwargs["b"] if "b" in kwargs else None
+        b = kwargs.get("b")
         # Output QParams
         o_inv_scale = (

View File

@@ -1074,13 +1074,13 @@ def _overload_method(func):
     _check_overload_body(func)
     qual_name = _qualified_name(func)
     global _overloaded_methods
-    class_name_map = _overloaded_methods.get(qual_name, None)
+    class_name_map = _overloaded_methods.get(qual_name)
     if class_name_map is None:
         class_name_map = {}
         _overloaded_methods[qual_name] = class_name_map
     class_name, line_no = get_class_name_lineno(func)
-    method_overloads = class_name_map.get(class_name, None)
+    method_overloads = class_name_map.get(class_name)
     if method_overloads is None:
         method_overloads = []
         class_name_map[class_name] = method_overloads
@@ -1102,7 +1102,7 @@ def _get_overloaded_methods(method, mod_class):
     if not hasattr(method, "__name__"):
         return None
     qual_name = _qualified_name(method)
-    class_name_map = _overloaded_methods.get(qual_name, None)
+    class_name_map = _overloaded_methods.get(qual_name)
     if class_name_map is None:
         return None
     overloads = class_name_map.get(mod_class.__name__, None)

View File

@@ -5307,7 +5307,7 @@ def grid_sampler_3d_backward(
 @register_meta([aten.full.default])
 def full(size, fill_value, *args, **kwargs):
-    dtype = kwargs.get("dtype", None)
+    dtype = kwargs.get("dtype")
     if not dtype:
         dtype = utils.get_dtype(fill_value)
     kwargs["dtype"] = dtype

View File

@@ -1409,7 +1409,7 @@ class _HigherOrderNamespace(types.ModuleType):
     def __getattr__(self, name: str) -> HigherOrderOperator:
         # Following _OpNamespace.__getattr__, we cache the op on this object.
-        op = _higher_order_ops.get(name, None)
+        op = _higher_order_ops.get(name)
         if op is None:
             raise AttributeError(
                 f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"

View File

@@ -87,13 +87,9 @@ class ReferenceQuantizedModule(torch.nn.Module):
             # for capturing `.item` operations
             self.weight_axis_int: int = self.weight_axis.item()  # type: ignore[operator, assignment]
         # pyrefly: ignore # bad-assignment
-        self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
-            "quant_min", None
-        )
+        self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
         # pyrefly: ignore # bad-assignment
-        self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
-            "quant_max", None
-        )
+        self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
     def get_weight(self):
         """

View File

@@ -240,29 +240,29 @@ scale_min_lower_bound=None, scale_max_upper_bound=None)
             "bias_type": torch.dtype
             "is_dynamic": bool
         """
-        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
+        input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY)
         if input_dtype is not None and not isinstance(
             input_dtype, (torch.dtype, DTypeWithConstraints)
         ):
             raise ValueError(
                 "Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
             )
-        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
+        output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY)
         if output_dtype is not None and not isinstance(
             output_dtype, (torch.dtype, DTypeWithConstraints)
         ):
             raise ValueError(
                 "Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
             )
-        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
+        weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY)
         if weight_dtype is not None and not isinstance(
             weight_dtype, (torch.dtype, DTypeWithConstraints)
         ):
             raise ValueError(
                 "Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
             )
-        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
-        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
+        bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY)
+        is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY)
         return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
     def to_dict(self) -> dict[str, Any]:
@@ -673,23 +673,23 @@ class BackendPatternConfig:
         for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
             conf.add_dtype_config(_get_dtype_config(d))
         conf.set_root_module(
-            backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY)  # type: ignore[arg-type]
         )
-        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None))  # type: ignore[arg-type]
+        conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY))  # type: ignore[arg-type]
         conf.set_reference_quantized_module(
-            backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY)  # type: ignore[arg-type]
         )
         conf.set_fused_module(
-            backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY)  # type: ignore[arg-type]
        )
         conf.set_fuser_method(
-            backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY)  # type: ignore[arg-type]
         )
         conf._set_root_node_getter(
-            backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY)  # type: ignore[arg-type]
         )
         conf._set_extra_inputs_getter(
-            backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)  # type: ignore[arg-type]
+            backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY)  # type: ignore[arg-type]
         )
         conf._set_num_tensor_args_to_observation_type(
             backend_pattern_config_dict.get(

View File

@@ -286,7 +286,7 @@ def get_fuser_method_new(
     op_patterns = _get_valid_patterns(op_pattern)
     fuser_method = None
     for op_pattern in op_patterns:
-        fuser_method = fuser_method_mapping.get(op_pattern, None)
+        fuser_method = fuser_method_mapping.get(op_pattern)
         if fuser_method is not None:
             break
     assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "

View File

@@ -168,7 +168,7 @@ def _find_matches(
     for node in reversed(graph.nodes):
         if node.name not in match_map and node.name not in all_matched:
             for pattern, quantize_handler_cls in patterns.items():
-                root_node_getter = root_node_getter_mapping.get(pattern, None)
+                root_node_getter = root_node_getter_mapping.get(pattern)
                 if _is_match(modules, node, pattern) and node.name not in match_map:
                     matched_node_pattern: list[Node] = []
                     record_match(pattern, node, node, matched_node_pattern, match_map)

View File

@@ -130,7 +130,7 @@ def _get_qspec_for_arg(
 ) -> Optional[QuantizationSpecBase]:
     while _is_activation_post_process_node(arg, named_modules):
         arg = arg.args[0]  # type: ignore[assignment]
-    return input_qspec_map.get(arg, None)
+    return input_qspec_map.get(arg)
 def _create_obs_or_fq_from_qspec(

View File

@@ -164,7 +164,7 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
         torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
         torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
     }
-    prepack_op = prepack_ops.get(conv_op, None)
+    prepack_op = prepack_ops.get(conv_op)
     assert prepack_op, f"Didn't find prepack op for {conv_op}"
     return prepack_op

View File

@@ -806,7 +806,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
         unexpected_keys: list[str],
         error_msgs: list[str],
     ):
-        version = local_metadata.get("version", None)
+        version = local_metadata.get("version")
         if version is not None and version < 3:
             local_state = ["min_vals", "max_vals"]
             expected_min_name = "min_vals"

View File

@@ -366,7 +366,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     if input_edge_obs_or_fq is None:
         return new_arg
-    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
+    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg)
     # the arg is observed as the output and is using the same instance as the input_edge
     # we'll reuse the inserted observer/fake_quant
     if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
@@ -497,11 +497,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     is_qat: bool,
     model_device: Optional[torch.device] = None,
 ):
-    this_node_quantization_annotation = (
-        node.meta["quantization_annotation"]
-        if "quantization_annotation" in node.meta
-        else None
-    )
+    this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
     if this_node_quantization_annotation is None:
         return

View File

@@ -343,7 +343,7 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
 # TODO: merge with get_static_quant_module_class
 def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
     """Get the quantized operator corresponding to the float operator"""
-    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
+    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
     assert quantized_op is not None, (
         f"Operator {str(float_op)} does not have corresponding quantized op"
     )

View File

@@ -1357,11 +1357,7 @@ class X86InductorQuantizer(Quantizer):
     def _annotate_output_share_observer_as_input(
         self, input_node: Node, source_node: Node
     ):
-        source_node_quantization_annotation = (
-            source_node.meta[QUANT_ANNOTATION_KEY]
-            if QUANT_ANNOTATION_KEY in source_node.meta
-            else None
-        )
+        source_node_quantization_annotation = source_node.meta.get(QUANT_ANNOTATION_KEY)
         if (
             source_node_quantization_annotation
             and source_node_quantization_annotation._is_output_of_quantized_pattern
@@ -1400,10 +1396,8 @@ class X86InductorQuantizer(Quantizer):
             return
         # Get the quantization_annotation from getitem_node
-        maxpool_node_quantization_annotation = (
-            maxpool_node.meta[QUANT_ANNOTATION_KEY]
-            if QUANT_ANNOTATION_KEY in maxpool_node.meta
-            else None
+        maxpool_node_quantization_annotation = maxpool_node.meta.get(
+            QUANT_ANNOTATION_KEY
         )
         if (
             maxpool_node_quantization_annotation
View File

@@ -159,10 +159,7 @@ class EventList(list):
             if p is not None:
                 assert p.fwd_thread is not None
                 t = (p.sequence_nr, p.fwd_thread)
-                if t in fwd_stacks:
-                    evt.stack = fwd_stacks[t]
-                else:
-                    evt.stack = []
+                evt.stack = fwd_stacks.get(t, [])
     @property
     def self_cpu_time_total(self):

View File

@@ -214,7 +214,7 @@ def replicate(
     state = replicate.state(module)
     module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
-    device_mesh = kwargs.get("device_mesh", None)
+    device_mesh = kwargs.get("device_mesh")
     if device_mesh is not None:
         from torch.distributed.device_mesh import _mesh_resources

View File

@@ -228,7 +228,7 @@ def replicate_impl(
     # Place Replicate leftmost for highest priority in the method resolution order
     for module in modules:
         cls = module.__class__
-        new_cls = cls_to_replicate_cls.get(cls, None)
+        new_cls = cls_to_replicate_cls.get(cls)
         if not new_cls:
             dct = {"__deepcopy__": _unimplemented_deepcopy}
             new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct)

View File

@@ -143,7 +143,7 @@ def register_tensor_creation_op(op):
         takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
         ``torch.full_like``.
         """
-        creation_op = tensor_like_creation_op_map.get(op, None)
+        creation_op = tensor_like_creation_op_map.get(op)
         if creation_op is None:
             raise RuntimeError(f"Tensor creation {op} not supported!")
         if kwargs is None:

View File

@@ -678,7 +678,7 @@ class ShardedTensor(ShardedTensorBase):
         copy_tensor = kwargs.get("copy", False)
         non_blocking = kwargs.get("non_blocking", False)
         memory_format = kwargs.get("memory_format", torch.preserve_format)
-        process_group = kwargs.get("process_group", None)
+        process_group = kwargs.get("process_group")
         if (
             not copy_tensor

View File

@@ -605,7 +605,7 @@ def _distribute_tensors(
     if pg is None:
         pg = dist.distributed_c10d._get_default_group()
     for key in keys:
-        _local_state = local_state_dict.get(key, None)
+        _local_state = local_state_dict.get(key)
         if _local_state is None or torch.is_tensor(_local_state):
             continue

View File

@@ -127,7 +127,7 @@ def aggregate_stats(
     }
     for mod in model.modules():
-        if mod_mem_stat := mod_mem_stats.get(mod, None):
+        if mod_mem_stat := mod_mem_stats.get(mod):
             if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
                 sac_runtime = tradeoff_stats.sac_runtime
                 sac_memory = tradeoff_stats.sac_memory

View File

@@ -711,7 +711,7 @@ class SACEstimator(TorchDispatchMode):
                 str(i in sac_stats.view_like_ops),
                 str(i in sac_stats.rand_ops),
                 str(i in sac_stats.saved_autograd_ops),
-                str(op_parent.get(i, None)),
+                str(op_parent.get(i)),
             ]
             table_data.append(row)
         # Define headers

View File

@@ -107,7 +107,7 @@ def auto_quantize(func, qtype, quant_loss=None):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        group = kwargs.get("group", None)
+        group = kwargs.get("group")
         async_op = kwargs.get("async_op", False)
         if async_op is True:
             raise RuntimeError("The async_op=True mode is not supported yet.")
@@ -133,8 +133,8 @@ def auto_quantize(func, qtype, quant_loss=None):
         elif func == dist.all_to_all_single:
             tensors = args[0]
-            out_splits = kwargs.get("out_splits", None)
-            in_splits = kwargs.get("in_splits", None)
+            out_splits = kwargs.get("out_splits")
+            in_splits = kwargs.get("in_splits")
             # Quantizing the input/output tensor
             input_tensors = _quantize_tensor(args[1], qtype)
             out_tensors = _quantize_tensor(tensors, qtype)

View File

@@ -631,7 +631,7 @@ class _FileSystemWriter(StorageWriter):
     def set_up_storage_writer(
         self, is_coordinator: bool, *args: Any, **kwargs: Any
     ) -> None:
-        self.rank = kwargs.get("rank", None)
+        self.rank = kwargs.get("rank")
         self.use_collectives = kwargs.get("use_collectives", True)
     def _metadata_exists(self) -> bool:
@@ -919,7 +919,7 @@ class FileSystemReader(StorageReader):
     # Implementing the abstract function in StorageReader
     def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
-        rank = kwargs.get("rank", None)
+        rank = kwargs.get("rank")
         path = self._get_metadata_path(rank)
         with self.fs.create_stream(path, "rb") as metadata_file:
             metadata = pickle.load(metadata_file)
@@ -934,7 +934,7 @@ class FileSystemReader(StorageReader):
         self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
     ) -> None:
         self.storage_data = metadata.storage_data
-        self.rank = kwargs.get("rank", None)
+        self.rank = kwargs.get("rank")
         self.use_collectives = kwargs.get("use_collectives", True)
         assert self.storage_data is not None

View File

@@ -31,11 +31,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
     msg_dict = {}
     # checkpoint ID can be passed in through the serializer or through the checkpoint id directly
-    storage_writer = kwargs.get("storage_writer", None)
-    storage_reader = kwargs.get("storage_reader", None)
-    planner = kwargs.get("planner", None)
+    storage_writer = kwargs.get("storage_writer")
+    storage_reader = kwargs.get("storage_reader")
+    planner = kwargs.get("planner")
-    checkpoint_id = kwargs.get("checkpoint_id", None)
+    checkpoint_id = kwargs.get("checkpoint_id")
     if not checkpoint_id and (serializer := storage_writer or storage_reader):
         # pyrefly: ignore # unbound-name
         checkpoint_id = getattr(serializer, "checkpoint_id", None)

View File

@@ -307,7 +307,7 @@ def _verify_options(
             continue
         fqns = _get_fqns(model, name)
-        fqn = fqn_param_mapping.get(param, None)
+        fqn = fqn_param_mapping.get(param)
         if fqn is not None:
             cast(set[str], fqn_param_mapping[param]).update(fqns)
             shared_params_mapping[param] = fqn_param_mapping[param]

View File

@@ -5081,7 +5081,7 @@ def _is_safe_to_split() -> bool:
     users must be aware that a pg is only splittable after the first collective is
     issued.
     """
-    return False if _get_default_group().bound_device_id is None else True
+    return _get_default_group().bound_device_id is not None
 @_time_logger

View File

@@ -88,10 +88,7 @@ def configure(handler: MetricHandler, group: Optional[str] = None):
 def getStream(group: str):
-    if group in _metrics_map:
-        handler = _metrics_map[group]
-    else:
-        handler = _default_metrics_handler
+    handler = _metrics_map.get(group, _default_metrics_handler)
     return MetricStream(group, handler)

View File

@@ -241,7 +241,7 @@ def fully_shard(
     # Place FSDP leftmost for highest priority in the method resolution order
     for module in modules:
         cls = module.__class__
-        new_cls = cls_to_fsdp_cls.get(cls, None)
+        new_cls = cls_to_fsdp_cls.get(cls)
         if not new_cls:
             dct = {"__deepcopy__": _unimplemented_deepcopy}
             new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)

View File

@@ -1270,7 +1270,7 @@ def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
     (which usually are FQNs) versus integers (which usually refer to param_ids
     from a vanilla torch.optim.Optimizer).
     """
-    state = optim_state_dict.get("state", None)
+    state = optim_state_dict.get("state")
     if not state:
         # If we cannot find a state, assume it is not NamedOptimizer as
         # NamedOptimizer has eager initialization.
@@ -1718,7 +1718,7 @@ def _convert_state_with_orig_params(
     # across ranks
     for optim_state_key in all_optim_state_keys:
         param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
-            optim_state_key, None
+            optim_state_key
         )
         if param_key is None and not optim_state_key.is_fsdp_managed:
@@ -1726,7 +1726,7 @@ def _convert_state_with_orig_params(
         if optim_state_key.is_fsdp_managed:
             fqn = optim_state_key.unflat_param_names[0]
-            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
+            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn)
             if fsdp_param_info is None:
                 # This can happen if the not all FSDP instances have all the
                 # parameters. This can happen with FSDP + some MPMD style
@@ -1804,7 +1804,7 @@ def _convert_state_with_flat_params(
     # across ranks
     for optim_state_key in all_optim_state_keys:
         param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
-            optim_state_key, None
+            optim_state_key
         )
         assert param_key is not None, (

View File

@@ -52,7 +52,7 @@ class _ScriptLocalOptimizer(nn.Module):
         all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
         # apply functional optimizer step with a list of gradients
         grads: list[Optional[Tensor]] = [
-            all_local_grads[p] if p in all_local_grads else None
+            all_local_grads[p] if p in all_local_grads else None  # noqa: SIM401
             for p in self._local_params
         ]

View File

@@ -189,7 +189,7 @@ def _insert_stage_symbolic_backward(
         output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
         if node in tuples:
             stage_output = tuples[node]
-            output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
+            output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
             outputs_with_grads_idxs = [
                 i for i, n in enumerate(tuples[node]) if n in live_nodes
             ]

View File

@@ -114,7 +114,7 @@ def get_param_groups(
         "intermediates": intersected,
     }
     for input_node in intersected:
-        existing = param_groups.get(input_node, None)
+        existing = param_groups.get(input_node)
         if existing is not None:
             existing["params"] = existing["params"].union(param_group["params"])
             existing["intermediates"] = existing["intermediates"].union(

View File

@@ -326,8 +326,7 @@ def _insert_copy_for_mutations(
             return_nodes_to_copy[return_node] = copy_node
     output_args = tuple(
-        return_nodes_to_copy[node] if node in return_nodes_to_copy else node
-        for node in user_output_nodes
+        return_nodes_to_copy.get(node, node) for node in user_output_nodes
     )
     with gm.graph.inserting_before(output_node):
         # Only return user outputs

View File

@@ -46,7 +46,7 @@ class NormalizeArgs(Transformer):
         def get_type(arg):
             if isinstance(arg, fx.Node):
-                return n.meta["type"] if "type" in n.meta else None
+                return n.meta.get("type")
             return type(arg)
         arg_types = map_aggregate(n.args, get_type)

View File

@@ -4414,7 +4414,7 @@ class ShapeEnv:
             size = []
             for i, val in enumerate(tensor_size):
                 sym = self.create_symbol(
-                    val if i not in hint_overrides else hint_overrides[i],
+                    hint_overrides.get(i, val),
                     TensorPropertySource(source, TensorProperty.SIZE, i),
                     dynamic_dims[i],
                     constraint_dims[i],
@@ -4615,7 +4615,7 @@ class ShapeEnv:
         sym_sizes = [
             self.create_symintnode(
                 sym,
-                hint=hint if i not in hint_overrides else hint_overrides[i],
+                hint=hint_overrides.get(i, hint),
                 source=TensorPropertySource(source, TensorProperty.SIZE, i),
             )
             for i, (sym, hint) in enumerate(zip(size, ex_size))

View File

@@ -930,7 +930,7 @@ class ParameterDict(Module):
         key (str): key to get from the ParameterDict
         default (Parameter, optional): value to return if key not present
         """
-        return self[key] if key in self else default
+        return self[key] if key in self else default  # noqa: SIM401
     def fromkeys(
         self, keys: Iterable[str], default: Optional[Any] = None

View File

@@ -1761,11 +1761,7 @@ class Module:
         if recording_scopes:
             # type ignore was added because at this point one knows that
             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-            name = (
-                torch.jit._trace._trace_module_map[self]  # type: ignore[index]
-                if self in torch.jit._trace._trace_module_map  # type: ignore[operator]
-                else None
-            )  # noqa: B950
+            name = torch.jit._trace._trace_module_map.get(self, None)  # type: ignore[operator, union-attr]
             if name:
                 tracing_state.push_scope(name)
             else:

View File

@@ -218,7 +218,7 @@ def _dump_DDP_relevant_env_vars():
     ]
     formatted_output = ""
     for var in relevant_env_vars:
-        value = os.environ[var] if var in os.environ else "N/A"
+        value = os.environ.get(var, "N/A")
         formatted_output += f"env:{var}={value}\n"
     print(formatted_output)

View File

@@ -783,8 +783,8 @@ class Optimizer:
             assert param_groups is not None
             for pg in param_groups:
                 if param_id in pg["params"]:
-                    fused = pg["fused"] if "fused" in pg else False
-                    capturable = pg["capturable"] if "capturable" in pg else False
+                    fused = pg.get("fused", False)
+                    capturable = pg.get("capturable", False)
                     break
             if key == "step":
                 if capturable or fused:

View File

@@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
         return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))
     def _apply_precision_override_for_test(self, test, param_kwargs):
-        dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
-        dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
+        dtype = param_kwargs.get("dtype")
+        dtype = param_kwargs.get("dtypes", dtype)
         if dtype:
             self.precision = self._get_precision_override(test, dtype)
             self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)

View File

@@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
     for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
         # The scalar we are passing to new_full must be the same dtype
         # as the one of the resulting tensor
-        use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
+        use_dtype = sample.kwargs.get('dtype', dtype)
         yield SampleInput(
             sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)

View File

@@ -725,7 +725,7 @@ class DistributedTest:
            lines = out.getvalue().splitlines()
            def format_line(var):
-                return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"
+                return f"env:{var}={os.environ.get(var, 'N/A')}"
            # Check relevant env vars
            vars = [
@@ -6212,7 +6212,7 @@ class DistributedTest:
        )
        def test_ddp_logging_data_cpu(self):
            def parse_env(var):
-                return os.environ[var] if var in os.environ else "N/A"
+                return os.environ.get(var, "N/A")
            dist.set_debug_level(dist.DebugLevel.INFO)
            _, group_id, _ = self._init_global_test()

View File

@@ -21,7 +21,7 @@ INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
 def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
-    return _MIRROR_REL_OP.get(type, None)
+    return _MIRROR_REL_OP.get(type)
 # Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.

View File

@@ -277,7 +277,7 @@ def create_graph(objects, *, context=None, filter=None):
     references = annotated_references(obj)
     for referrent in gc.get_referents(obj):
         rid = id(referrent)
-        tidx = id_to_node.get(rid, None)
+        tidx = id_to_node.get(rid)
         if tidx is None:
             continue
         labels = references.get(rid, ["?"])