[1/N] Use "is" in python type comparison (#165037)

It generally recommended to use `is/is not` to compare types. Therefore this series of changes apply this suggestion in the code base, and it aims to finally enabling related linter checks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
This commit is contained in:
Yuanyuan Chen
2025-10-10 12:36:46 +00:00
committed by PyTorch MergeBot
parent 960b0d5f0d
commit 70925bdf82
37 changed files with 51 additions and 51 deletions

View File

@ -2278,7 +2278,7 @@ class GuardBuilder(GuardBuilderBase):
# don't support this in serialization because it uses unsupported FUNCTION_MATCH # don't support this in serialization because it uses unsupported FUNCTION_MATCH
val = self.get(guard.name) val = self.get(guard.name)
# Strictly only want user-defined functions # Strictly only want user-defined functions
if type(val) == types.FunctionType and hasattr(val, "__code__"): if type(val) is types.FunctionType and hasattr(val, "__code__"):
self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type] self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type] self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
else: else:

View File

@ -197,7 +197,7 @@ class ConstDictVariable(VariableTracker):
@staticmethod @staticmethod
def _eq_impl(a, b): def _eq_impl(a, b):
# TODO: Put this in utils and share it between variables/builtin.py and here # TODO: Put this in utils and share it between variables/builtin.py and here
if type(a) != type(b): if type(a) is not type(b):
return False return False
elif isinstance(a, tuple): elif isinstance(a, tuple):
Hashable = ConstDictVariable._HashableTracker Hashable = ConstDictVariable._HashableTracker

View File

@ -3532,7 +3532,7 @@ class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable):
from torch.distributed.tensor.experimental._func_map import _local_map_wrapped from torch.distributed.tensor.experimental._func_map import _local_map_wrapped
# check is important to avoid subclass dispatch # check is important to avoid subclass dispatch
if type(value) != type(_local_map_wrapped): if type(value) is not type(_local_map_wrapped):
return False return False
return value == _local_map_wrapped and cls._enabled return value == _local_map_wrapped and cls._enabled

View File

@ -200,7 +200,7 @@ class BaseListVariable(VariableTracker):
if kwargs or len(args) != 1: if kwargs or len(args) != 1:
raise_args_mismatch(tx, name) raise_args_mismatch(tx, name)
if type(self) != type(args[0]): if type(self) is not type(args[0]):
tp_name = self.python_type_name() tp_name = self.python_type_name()
other = args[0].python_type_name() other = args[0].python_type_name()
msg = ConstantVariable.create( msg = ConstantVariable.create(

View File

@ -464,7 +464,7 @@ def _check_input_constraints_for_graph(
) )
elif isinstance(node_val, (int, float, str)): elif isinstance(node_val, (int, float, str)):
if type(arg) != type(node_val) or arg != node_val: if type(arg) is not type(node_val) or arg != node_val:
raise RuntimeError( raise RuntimeError(
f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
) )

View File

@ -361,7 +361,7 @@ def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_
input_meta += get_input_meta(args[1]) input_meta += get_input_meta(args[1])
return input_meta return input_meta
for arg in args: for arg in args:
if type(arg) == int or type(arg) == float: if type(arg) is int or type(arg) is float:
input_meta.append((type(arg),)) input_meta.append((type(arg),))
else: else:
input_meta.append( input_meta.append(

View File

@ -802,7 +802,7 @@ def _insert_fused_matmul_reduce_scatter(
scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern
output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern
) -> torch.fx.Node: ) -> torch.fx.Node:
if type(matmul) == _Matmul: if type(matmul) is _Matmul:
return graph.call_function( return graph.call_function(
torch.ops.symm_mem.fused_matmul_reduce_scatter.default, torch.ops.symm_mem.fused_matmul_reduce_scatter.default,
args=( args=(
@ -813,7 +813,7 @@ def _insert_fused_matmul_reduce_scatter(
group_name, group_name,
), ),
) )
elif type(matmul) == _ScaledMatmul: elif type(matmul) is _ScaledMatmul:
return graph.call_function( return graph.call_function(
torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default, torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default,
args=( args=(

View File

@ -98,7 +98,7 @@ class FakeTensorUpdater:
return statically_known_true(sym_eq(new, old)) return statically_known_true(sym_eq(new, old))
def is_fake_tensor_same(new, old, *, node): def is_fake_tensor_same(new, old, *, node):
if type(new) != type(old): if type(new) is not type(old):
return False return False
if isinstance(new, (list, tuple)): if isinstance(new, (list, tuple)):
if len(new) != len(old): if len(new) != len(old):

View File

@ -122,7 +122,7 @@ class SchemaCheckMode(TorchDispatchMode):
def parse_metadata(e): def parse_metadata(e):
if isinstance(e, torch.Tensor): if isinstance(e, torch.Tensor):
if type(e) != torch.Tensor: if type(e) is not torch.Tensor:
try: try:
current = e.elem current = e.elem
return ( return (

View File

@ -37,7 +37,7 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
if isinstance(dtype, str): if isinstance(dtype, str):
dtype = _import_dotted_name(dtype) dtype = _import_dotted_name(dtype)
if dtype == type(self): if dtype is type(self):
return self return self
if self.is_sparse: if self.is_sparse:
if not dtype.is_sparse: if not dtype.is_sparse:

View File

@ -80,7 +80,7 @@ class ConvReLU1d(nnq.Conv1d):
@classmethod @classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.weight,
@ -161,7 +161,7 @@ class ConvReLU2d(nnq.Conv2d):
@classmethod @classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.weight,
@ -244,7 +244,7 @@ class ConvReLU3d(nnq.Conv3d):
@classmethod @classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override] def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.weight,

View File

@ -117,7 +117,7 @@ class Linear(nnq.Linear):
+ str([float_mod.__name__ for float_mod in float_modules]) + str([float_mod.__name__ for float_mod in float_modules])
) )
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU: if type(mod) is nni.LinearReLU:
mod = mod[0] mod = mod[0]
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None: if mod.qconfig is not None and mod.qconfig.weight is not None:

View File

@ -1064,15 +1064,15 @@ class RNNCellBase(torch.nn.Module):
qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell] qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
if type(mod) == torch.nn.LSTMCell: if type(mod) is torch.nn.LSTMCell:
qRNNCellBase = LSTMCell( qRNNCellBase = LSTMCell(
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
) )
elif type(mod) == torch.nn.GRUCell: elif type(mod) is torch.nn.GRUCell:
qRNNCellBase = GRUCell( qRNNCellBase = GRUCell(
mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
) )
elif type(mod) == torch.nn.RNNCell: elif type(mod) is torch.nn.RNNCell:
qRNNCellBase = RNNCell( qRNNCellBase = RNNCell(
mod.input_size, mod.input_size,
mod.hidden_size, mod.hidden_size,

View File

@ -20,7 +20,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
@staticmethod @staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False): def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_BN_RELU_MODULE: if type(mod) is cls._NNI_BN_RELU_MODULE:
mod = mod[0] mod = mod[0]
scale, zero_point = activation_post_process.calculate_qparams() scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps) new_mod = cls(mod.num_features, mod.eps)

View File

@ -280,7 +280,7 @@ class _ConvNd(WeightedQuantizedModule):
if hasattr(mod, "weight_fake_quant"): if hasattr(mod, "weight_fake_quant"):
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
# ".from_float only works for " + cls.__QAT_MODULE.__name__ # ".from_float only works for " + cls.__QAT_MODULE.__name__
if type(mod) == cls._NNIQAT_CONV_BN_MODULE: if type(mod) is cls._NNIQAT_CONV_BN_MODULE:
mod.weight, mod.bias = fuse_conv_bn_weights( mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.weight,
mod.bias, mod.bias,

View File

@ -149,7 +149,7 @@ class Linear(torch.nn.Module):
# TODO: Need to add options to qconfig to avoid the calibration. # TODO: Need to add options to qconfig to avoid the calibration.
# TODO: Add calibration for the sparsity # TODO: Add calibration for the sparsity
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU: if type(mod) is nni.LinearReLU:
mod = mod[0] mod = mod[0]
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute
if mod.qconfig is not None and mod.qconfig.weight is not None: if mod.qconfig is not None and mod.qconfig.weight is not None:

View File

@ -268,7 +268,7 @@ def extract_weight_from_node(
mod = getattr_from_fqn(gm, node.target) mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"] module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items(): for target_mod_type, weight_extraction_fn in module_mapping.items():
if type(mod) == target_mod_type: if type(mod) is target_mod_type:
weight = weight_extraction_fn(mod) weight = weight_extraction_fn(mod)
return { return {
"type": res_type, "type": res_type,

View File

@ -792,7 +792,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
for i in range(1, len(path)): for i in range(1, len(path)):
prev_key = path[i - 1] prev_key = path[i - 1]
key = path[i] key = path[i]
def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else [] def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) is str else []
if isinstance(cur_container, Mapping): if isinstance(cur_container, Mapping):
cur_container = cast( cur_container = cast(
@ -806,7 +806,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
cur_container = cur_container[prev_key] cur_container = cur_container[prev_key]
key = path[-1] key = path[-1]
if type(key) == int: if type(key) is int:
extend_list(cast(list[Any], cur_container), key) extend_list(cast(list[Any], cur_container), key)
cur_container[key] = value cur_container[key] = value

View File

@ -121,7 +121,7 @@ def set_element(
for i in range(1, len(path)): for i in range(1, len(path)):
prev_key = path[i - 1] prev_key = path[i - 1]
key = path[i] key = path[i]
def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) def_val = cast(STATE_DICT_ITEM, {} if type(key) is str else [])
if isinstance(cur_container, Mapping): if isinstance(cur_container, Mapping):
cur_container = cast( cur_container = cast(
@ -135,7 +135,7 @@ def set_element(
cur_container = cur_container[prev_key] cur_container = cur_container[prev_key]
key = path[-1] key = path[-1]
if type(key) == int: if type(key) is int:
extend_list(cast(list[STATE_DICT_ITEM], cur_container), key) extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
cur_container[key] = value cur_container[key] = value

View File

@ -2203,7 +2203,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
# alive until all works and hooks are done. The current implementation does the # alive until all works and hooks are done. The current implementation does the
# latter. Therefore, we explicitly call _wait_for_pending_works() here to wait # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
# for the pending hooks to finish. # for the pending hooks to finish.
if type(pg) == ProcessGroup and pg._has_hooks(): if type(pg) is ProcessGroup and pg._has_hooks():
pg._wait_for_pending_works() pg._wait_for_pending_works()
if group is None or group == GroupMember.WORLD: if group is None or group == GroupMember.WORLD:
@ -2783,7 +2783,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
key = "group_dst" if op.op == isend else "group_src" key = "group_dst" if op.op == isend else "group_src"
return {key: op.group_peer} return {key: op.group_peer}
if type(group) == ProcessGroup and group._get_backend(device).supports_coalescing: if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
# NCCL style coalescing # NCCL style coalescing
with _coalescing_manager(group, device, async_ops=True) as cm: with _coalescing_manager(group, device, async_ops=True) as cm:
for p2p_op in p2p_op_list: for p2p_op in p2p_op_list:

View File

@ -149,9 +149,9 @@ class EtcdStore(Store):
# In case of `str`, utf-8 encoding is assumed. # In case of `str`, utf-8 encoding is assumed.
# #
def _encode(self, value) -> str: def _encode(self, value) -> str:
if type(value) == bytes: if type(value) is bytes:
return b64encode(value).decode() return b64encode(value).decode()
elif type(value) == str: elif type(value) is str:
return b64encode(value.encode()).decode() return b64encode(value.encode()).decode()
raise ValueError("Value must be of type str or bytes") raise ValueError("Value must be of type str or bytes")
@ -160,9 +160,9 @@ class EtcdStore(Store):
# Return type is `bytes`, which is more convenient with the Store interface. # Return type is `bytes`, which is more convenient with the Store interface.
# #
def _decode(self, value) -> bytes: def _decode(self, value) -> bytes:
if type(value) == bytes: if type(value) is bytes:
return b64decode(value) return b64decode(value)
elif type(value) == str: elif type(value) is str:
return b64decode(value.encode()) return b64decode(value.encode())
raise ValueError("Value must be of type str or bytes") raise ValueError("Value must be of type str or bytes")

View File

@ -110,7 +110,7 @@ class DownSampling(nn.Module):
@torch.no_grad() @torch.no_grad()
def init_weights(m): def init_weights(m):
if type(m) == nn.Conv2d or type(m) == nn.Linear: if type(m) is nn.Conv2d or type(m) is nn.Linear:
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)

View File

@ -280,7 +280,7 @@ def _kl_exponential_exponential(p, q):
@register_kl(ExponentialFamily, ExponentialFamily) @register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q): def _kl_expfamily_expfamily(p, q):
if not type(p) == type(q): if type(p) is not type(q):
raise NotImplementedError( raise NotImplementedError(
"The cross KL-divergence between different exponential families cannot \ "The cross KL-divergence between different exponential families cannot \
be computed using Bregman divergences" be computed using Bregman divergences"

View File

@ -872,7 +872,7 @@ class AdditionalInputs:
] ]
def _mark_dynamism(v, *other_vs): def _mark_dynamism(v, *other_vs):
if not all(type(v) == type(other) for other in other_vs): if not all(type(v) is type(other) for other in other_vs):
raise ValueError( raise ValueError(
"The following inputs were found to have differing types, " "The following inputs were found to have differing types, "
f"so they cannot be marked as dynamic: {(v,) + other_vs}." f"so they cannot be marked as dynamic: {(v,) + other_vs}."

View File

@ -927,7 +927,7 @@ class Tracer(TracerBase):
return out return out
# Union[int, bool] == bool in Python <= 3.6 # Union[int, bool] == bool in Python <= 3.6
if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor: if type(x) is bool or type(x) in base_types and type(x) != torch.Tensor:
torch._assert( torch._assert(
out == x, out == x,
f"{name} has been specialized to have value {x} but got another value", f"{name} has been specialized to have value {x} but got another value",

View File

@ -112,7 +112,7 @@ def unify_object(u, v, s):
>>> unify_object(f, g, {}) >>> unify_object(f, g, {})
{~x: 2} {~x: 2}
""" """
if type(u) != type(v): if type(u) is not type(v):
return False return False
if hasattr(u, "__slots__"): if hasattr(u, "__slots__"):
return unify( return unify(

View File

@ -311,7 +311,7 @@ class ScriptMeta(type):
original_init(self, *args, **kwargs) original_init(self, *args, **kwargs)
added_methods_in_init = len(cls._methods) > num_methods added_methods_in_init = len(cls._methods) > num_methods
if type(self) == cls: if type(self) is cls:
def make_stubs(module): def make_stubs(module):
cls = type(module) cls = type(module)
@ -804,7 +804,7 @@ if _enabled:
@property @property
def original_name(self): def original_name(self):
if type(self) == str(self._c._type().name()): if type(self) is str(self._c._type().name()):
return "" return ""
return str(self._c._type().name()) return str(self._c._type().name())

View File

@ -651,7 +651,7 @@ def analyze_ts_result_with_export_result(export, trace):
# mkldnn is not supported for torch.allclose # mkldnn is not supported for torch.allclose
if orig.layout == torch._mkldnn: # type: ignore[attr-defined] if orig.layout == torch._mkldnn: # type: ignore[attr-defined]
return True return True
if type(orig) != type(loaded): if type(orig) is not type(loaded):
return False return False
if isinstance(orig, torch._subclasses.FakeTensor): if isinstance(orig, torch._subclasses.FakeTensor):

View File

@ -198,7 +198,7 @@ class MaskedTensor(torch.Tensor):
def _validate_members(self): def _validate_members(self):
data = self._masked_data data = self._masked_data
mask = self.get_mask() mask = self.get_mask()
if type(data) != type(mask): if type(data) is not type(mask):
raise TypeError( raise TypeError(
f"data and mask must have the same type. Got {type(data)} and {type(mask)}" f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
) )

View File

@ -1049,7 +1049,7 @@ class Module:
>>> @torch.no_grad() >>> @torch.no_grad()
>>> def init_weights(m): >>> def init_weights(m):
>>> print(m) >>> print(m)
>>> if type(m) == nn.Linear: >>> if type(m) is nn.Linear:
>>> m.weight.fill_(1.0) >>> m.weight.fill_(1.0)
>>> print(m.weight) >>> print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

View File

@ -10,7 +10,7 @@ def is_from_package(obj: Any) -> bool:
Note: packaged objects from externed modules will return ``False``. Note: packaged objects from externed modules will return ``False``.
""" """
if type(obj) == ModuleType: if type(obj) is ModuleType:
return is_mangled(obj.__name__) return is_mangled(obj.__name__)
else: else:
return is_mangled(type(obj).__module__) return is_mangled(type(obj).__module__)

View File

@ -891,7 +891,7 @@ class TypedStorage:
def _new_wrapped_storage(self, untyped_storage) -> Self: def _new_wrapped_storage(self, untyped_storage) -> Self:
assert type(untyped_storage) == torch.UntypedStorage assert type(untyped_storage) == torch.UntypedStorage
if type(self) == TypedStorage: if type(self) is TypedStorage:
return cast( return cast(
Self, Self,
TypedStorage( TypedStorage(
@ -913,7 +913,7 @@ class TypedStorage:
return 0 return 0
else: else:
if type(idx) != int: if type(idx) is not int:
raise TypeError(f"can't index a {type(self)} with {type(idx)}") raise TypeError(f"can't index a {type(self)} with {type(idx)}")
if is_stop: if is_stop:
if (idx > self._size()) or (idx < -self._size()): if (idx > self._size()) or (idx < -self._size()):
@ -1513,7 +1513,7 @@ class _LegacyStorageMeta(type):
dtype: torch.dtype dtype: torch.dtype
def __instancecheck__(cls, instance): def __instancecheck__(cls, instance):
if type(instance) == TypedStorage: if type(instance) is TypedStorage:
cls_device = _get_device_from_module(cls.__module__) cls_device = _get_device_from_module(cls.__module__)
return (cls_device == instance.device.type) and ( return (cls_device == instance.device.type) and (
cls.dtype == instance.dtype cls.dtype == instance.dtype

View File

@ -1432,7 +1432,7 @@ class MultiThreadedTestCase(TestCase):
logger.error("Caught exception: \n%s exiting thread %s", msg, rank) logger.error("Caught exception: \n%s exiting thread %s", msg, rank)
error_msg += f"Thread {rank} exited with exception:\n{msg}\n" error_msg += f"Thread {rank} exited with exception:\n{msg}\n"
elif isinstance(exc, SystemExit): elif isinstance(exc, SystemExit):
if type(exc.code) == int and skip_code < 0: if type(exc.code) is int and skip_code < 0:
skip_code = exc.code skip_code = exc.code
# check exceptions # check exceptions

View File

@ -1247,7 +1247,7 @@ class QuantizationTestCase(TestCase):
} }
""" """
# TODO: make img_data a single example instead of a list # TODO: make img_data a single example instead of a list
if type(inputs) == list: if type(inputs) is list:
inputs = inputs[0] inputs = inputs[0]
if quant_type == QuantType.QAT: if quant_type == QuantType.QAT:

View File

@ -365,7 +365,7 @@ def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs)
for mask in _generate_masked_op_mask( for mask in _generate_masked_op_mask(
sample_input.input.shape, device, **kwargs sample_input.input.shape, device, **kwargs
): ):
if type(mask) != torch.Tensor: if type(mask) is not torch.Tensor:
continue continue
sample_input_args, sample_input_kwargs = ( sample_input_args, sample_input_kwargs = (
sample_input.args, sample_input.args,

View File

@ -70,7 +70,7 @@ class WrapperSubclass(torch.Tensor):
def __coerce_same_metadata_as_tangent__( def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[type] = None self, expected_metadata: Any, expected_type: Optional[type] = None
): ):
if expected_type == type(self.a): if expected_type is type(self.a):
return self.a return self.a
elif expected_type is TwoTensor: elif expected_type is TwoTensor:
return TwoTensor(self.a, self.a.clone()) return TwoTensor(self.a, self.a.clone())

View File

@ -187,7 +187,7 @@ class GraphPy:
) )
for key, node in self.nodes_io.items(): for key, node in self.nodes_io.items():
if type(node) == NodeBase: if type(node) is NodeBase:
# pyrefly: ignore # unsupported-operation # pyrefly: ignore # unsupported-operation
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"): if hasattr(node, "input_or_output"):