Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE][Ez]: ISC001 Auto concatenate implicit one line strings (#146408)
Apply the ruff rule for implicit string concatenation (ISC001), which autofixes adjacent string literals that are all of the same type and on the same line. These literals were most likely split apart by autoformatters in the past. All fixes were generated automatically by the ISC001 autofix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146408 Approved by: https://github.com/justinchuby, https://github.com/janeyx99
This commit is contained in: commit 292af3cc89, parent f38a2ea0d4, committed by PyTorch MergeBot.
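For readers unfamiliar with the rule, here is a minimal sketch of what the ISC001 autofix does (hypothetical example, not taken from this commit): adjacent string literals of the same kind on a single line are merged into one literal, with no change to the runtime value.

# Before: implicit concatenation on one line, typically left behind by an autoformatter
message = "implicit string " "concatenation"
# After the ISC001 autofix: a single equivalent literal
message = "implicit string concatenation"

Assuming a recent ruff version, fixes like the ones in this diff can be reproduced locally with a command along the lines of `ruff check --select ISC001 --fix .`.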
@@ -119,7 +119,7 @@ def test(args, model, test_loader, device):
     top1_avg = np.mean(top1_acc)

-    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg :.6f} ")

     return np.mean(top1_acc)
@@ -185,7 +185,7 @@ def test(args, model, test_loader, device):
     top1_avg = np.mean(top1_acc)

-    print(f"\tTest set:" f"Loss: {np.mean(losses):.6f} " f"Acc@1: {top1_avg :.6f} ")
+    print(f"\tTest set:Loss: {np.mean(losses):.6f} Acc@1: {top1_avg :.6f} ")

     return np.mean(top1_acc)
@@ -61,7 +61,7 @@ class _GradAccConfig:
     def __repr__(self) -> str:
         # Override to remove any spaces in the string to appease the internal
         # build's test name parser
-        return f"(use_no_sync={self.use_no_sync}," f"num_iters={self.num_iters})"
+        return f"(use_no_sync={self.use_no_sync},num_iters={self.num_iters})"


 @dataclass
@@ -36,7 +36,7 @@ class TestException(TestCase):
         with self.assertRaisesRegex(
             RuntimeError,
-            "This op may not exist or may not be currently " "supported in TorchScript",
+            "This op may not exist or may not be currently supported in TorchScript",
         ):

             @torch.jit.script
@@ -1702,7 +1702,7 @@ class TestTracer(JitTestCase):
     def test_trace_checker_dot_data(self):
         with self.assertRaisesRegex(
             torch.jit.TracingCheckError,
-            r"Tensor-valued Constant nodes differed in value " r"across invocations",
+            r"Tensor-valued Constant nodes differed in value across invocations",
         ):

             @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
@@ -754,7 +754,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[List[str], List[torch.Tensor]]",
             lhs["list_literal_of_mixed"],
-            "none of those types match the types of the" " given list elements",
+            "none of those types match the types of the given list elements",
         )

         self._assert_passes(
@@ -784,21 +784,21 @@ class TestUnion(JitTestCase):
             template,
             "Union[int, torch.Tensor]",
             lhs["list_literal_empty"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         self._assert_raises(
             template,
             "Union[int, torch.Tensor]",
             lhs["list_literal_of_tensor"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         self._assert_raises(
             template,
             "Union[int, torch.Tensor]",
             lhs["list_comprehension_of_tensor"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         """
@@ -890,7 +890,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[List[str], List[torch.Tensor]]",
             lhs["dict_literal_empty"],
-            "Expected an Union type annotation with an " "inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         self._assert_passes(
@@ -974,14 +974,14 @@ class TestUnion(JitTestCase):
             template,
             "Union[int, torch.Tensor]",
             lhs["dict_literal_empty"],
-            "Expected an Union type annotation with " "an inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         self._assert_raises(
             template,
             "Union[int, torch.Tensor]",
             lhs["dict_literal_of_str_tensor"],
-            "Expected an Union type annotation with " "an inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         # See above--string frontend does not support tuple unpacking
@@ -762,7 +762,7 @@ class TestUnion(JitTestCase):
             template,
             "List[str] | List[torch.Tensor]",
             lhs["list_literal_of_mixed"],
-            "none of those types match the types of the" " given list elements",
+            "none of those types match the types of the given list elements",
         )

         self._assert_passes(
@@ -790,21 +790,21 @@ class TestUnion(JitTestCase):
             template,
             "int | torch.Tensor",
             lhs["list_literal_empty"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         self._assert_raises(
             template,
             "int | torch.Tensor",
             lhs["list_literal_of_tensor"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         self._assert_raises(
             template,
             "int | torch.Tensor",
             lhs["list_comprehension_of_tensor"],
-            "Expected an Union type annotation with an " "inner List type",
+            "Expected an Union type annotation with an inner List type",
         )

         """
@@ -894,7 +894,7 @@ class TestUnion(JitTestCase):
             template,
             "List[str] | List[torch.Tensor]",
             lhs["dict_literal_empty"],
-            "Expected an Union type annotation with an " "inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         self._assert_passes(
@@ -978,14 +978,14 @@ class TestUnion(JitTestCase):
             template,
             "int | torch.Tensor",
             lhs["dict_literal_empty"],
-            "Expected an Union type annotation with " "an inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         self._assert_raises(
             template,
             "int | torch.Tensor",
             lhs["dict_literal_of_str_tensor"],
-            "Expected an Union type annotation with " "an inner Dict type",
+            "Expected an Union type annotation with an inner Dict type",
         )

         # See above--string frontend does not support tuple unpacking
@@ -253,7 +253,7 @@ class TestFindMismatch(pytorch_test_common.ExportTestCase):
         leaf_info.pretty_print_mismatch(graph=True)
         self.assertRegex(
             f.getvalue(),
-            r"(.|\n)*" r"aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
+            r"(.|\n)*aten::relu.*/torch/nn/functional.py:[0-9]+(.|\n)*",
         )

     def test_find_all_mismatch_operators(self):
@@ -8091,7 +8091,7 @@ for shape in [(1,), ()]:
             view_a = a.unbind()[0]
             with self.assertRaisesRegex(
                 RuntimeError,
-                "This view is the output of a function that returns " "multiple views.",
+                "This view is the output of a function that returns multiple views.",
             ):
                 view_a.copy_(b)
@@ -23,7 +23,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile

 @unittest.skipUnless(
     torch.backends.xnnpack.enabled,
-    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
+    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
 )
 @unittest.skipIf(
     TEST_WITH_TSAN,
@@ -231,7 +231,7 @@ class TestXNNPACKOps(TestCase):

 @unittest.skipUnless(
     torch.backends.xnnpack.enabled,
-    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
+    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
 )
 @unittest.skipIf(
     TEST_WITH_TSAN,
@@ -753,7 +753,7 @@ class TestXNNPACKSerDes(TestCase):

 @unittest.skipUnless(
     torch.backends.xnnpack.enabled,
-    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
+    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
 )
 @unittest.skipIf(
     TEST_WITH_TSAN,
@@ -1241,7 +1241,7 @@ class TestXNNPACKRewritePass(TestCase):

 @unittest.skipUnless(
     torch.backends.xnnpack.enabled,
-    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
+    " XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
 )
 @unittest.skipIf(
     TEST_WITH_TSAN,
@@ -3029,7 +3029,7 @@ class TestBroadcast(TestCase):
     def test_shape_mismatch_error_message(self):
         with assert_raises(
             ValueError,
-            match=r"arg 0 with shape \(1, 3\) and " r"arg 2 with shape \(2,\)",
+            match=r"arg 0 with shape \(1, 3\) and arg 2 with shape \(2,\)",
         ):
             np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])
@@ -250,7 +250,7 @@ class TestPower(TestCase):
                 a = t1(3)
                 b = t2(2)
                 result = a**b
-                msg = f"error with {t1!r} and {t2!r}:" f"got {result!r}, expected {9!r}"
+                msg = f"error with {t1!r} and {t2!r}:got {result!r}, expected {9!r}"
                 if np.issubdtype(np.dtype(result), np.integer):
                     assert_(result == 9, msg)
                 else:
@@ -105,7 +105,7 @@ def parse_args() -> argparse.Namespace:
         "--destination",
         default="dist/",
         type=str,
-        help=("Destination to put the compailed binaries" ""),
+        help=("Destination to put the compailed binaries"),
     )
     return parser.parse_args()
@@ -1482,7 +1482,7 @@ def non_single_tensor_return_unsupported(api, ret):

     if not isinstance(ret, TensorVariable):
         raise Unsupported(
-            f"{api} over function that returns something " f"other than one Tensor"
+            f"{api} over function that returns something other than one Tensor"
         )
@@ -73,9 +73,7 @@ def _create_differentiable(inps, level=None):
         if isinstance(x, torch.Tensor):
             with enable_inplace_requires_grad(True):
                 return _set_tensor_requires_grad(x)
-        raise ValueError(
-            f"Thing passed to transform API must be Tensor, " f"got {type(x)}"
-        )
+        raise ValueError(f"Thing passed to transform API must be Tensor, got {type(x)}")

     return tree_map(create_differentiable, inps)
@@ -954,7 +952,7 @@ def assert_non_empty_list_of_tensors(
         if isinstance(out, torch.Tensor):
             continue
         raise RuntimeError(
-            f"{api}: Expected {argname} to only contain Tensors, got " f"{type(out)}"
+            f"{api}: Expected {argname} to only contain Tensors, got {type(out)}"
         )
@@ -885,7 +885,7 @@ def solve_min_cut(
         import networkx as nx
     except ImportError as e:
         raise RuntimeError(
-            "Need networkx installed to perform smart recomputation " "heuristics"
+            "Need networkx installed to perform smart recomputation heuristics"
         ) from e

     def is_materialized_backwards(node):
@@ -66,9 +66,7 @@ def infer_schema(
     sig = inspect.signature(prototype_function)

     def error_fn(what):
-        raise ValueError(
-            f"infer_schema(func): {what} " f"Got func with signature {sig})"
-        )
+        raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")

     def convert_type_string(annotation_type: str):
         try:
@@ -689,9 +689,7 @@ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
             f")"
         )
     if torch._C._functorch.is_gradtrackingtensor(tensor):
-        return (
-            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
-        )
+        return f"GradTrackingTensor(lvl={level}, value=\n{indented_value_repr}\n)"
     if torch._C._functorch.is_functionaltensor(tensor):
         return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"
@@ -69,7 +69,7 @@ def _validate_output_tensor_for_gather(
         )
     elif dst_tensor:
         raise ValueError(
-            "Argument ``dst_tensor`` must NOT be specified " "on non-destination ranks."
+            "Argument ``dst_tensor`` must NOT be specified on non-destination ranks."
         )
@@ -117,7 +117,7 @@ class ErrorHandler:
             rootcause_error_file, rootcause_error, error_code
         )
         logger.debug(
-            "child error file (%s) contents:\n" "%s",
+            "child error file (%s) contents:\n%s",
             rootcause_error_file,
             json.dumps(rootcause_error, indent=2),
         )
|
@ -194,7 +194,7 @@ class TimerServer(abc.ABC):
|
||||
reaped_worker_ids = set()
|
||||
for worker_id, expired_timers in self.get_expired_timers(now).items():
|
||||
logger.info(
|
||||
"Reaping worker_id=[%s]." " Expired timers: %s",
|
||||
"Reaping worker_id=[%s]. Expired timers: %s",
|
||||
worker_id,
|
||||
self._get_scopes(expired_timers),
|
||||
)
|
||||
@@ -212,7 +212,7 @@ class TimerServer(abc.ABC):

     def start(self) -> None:
         logger.info(
-            "Starting %s..." " max_interval=%s," " daemon=%s",
+            "Starting %s... max_interval=%s, daemon=%s",
             type(self).__name__,
             self._max_interval,
             self._daemon,
@@ -50,7 +50,7 @@ def as_functional_optim(optim_cls: type, *args, **kwargs):
         functional_cls = functional_optim_map[optim_cls]
     except KeyError as e:
         raise ValueError(
-            f"Optimizer {optim_cls} does not have a functional " f"counterpart!"
+            f"Optimizer {optim_cls} does not have a functional counterpart!"
         ) from e

     return _create_functional_optim(functional_cls, *args, **kwargs)
@@ -1517,14 +1517,14 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
             self._bucket_assignments_per_rank[self.global_rank]
         )
         logger.info(
-            "rank %s with %s parameters " "across %s buckets",
+            "rank %s with %s parameters across %s buckets",
             self.global_rank,
             local_numel,
             num_assigned_buckets,
         )
         if self.global_rank == 0:
             logger.info(
-                "%s DDP " "buckets and " "%s bucket " "assignments",
+                "%s DDP buckets and %s bucket assignments",
                 len(self._overlap_info.params_per_bucket),
                 self._overlap_info.num_bucket_assignments,
             )
@@ -188,9 +188,7 @@ def _validate_device_maps(
     for node in all_names:
         devices = all_devices[node]
         if len(set(devices)) != len(devices):
-            raise ValueError(
-                f"Node {node} has duplicated devices\n" f"devices = {devices}"
-            )
+            raise ValueError(f"Node {node} has duplicated devices\ndevices = {devices}")
         if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
             raise ValueError(
                 f"Node {node} has devices with invalid indices\n"
@@ -421,7 +421,7 @@ class ShardingPropagator:
                 raise e
             except Exception as e:
                 raise RuntimeError(
-                    f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
+                    f"Sharding propagation failed on op {op_schema}.\nError: {e}"
                 ) from e

             # step 2. if can't get output_spec from sharding
@@ -6397,7 +6397,7 @@ class ShapeEnv:
             maybe_extra_debug += "\nC++ stack trace:\n" + "".join(cpp_stack.format())
         elif is_debug:
             maybe_extra_debug += (
-                "\nFor C++ stack trace, run with " "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1"
+                "\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1"
             )

         return SLoc(floc, maybe_user_loc), maybe_extra_debug
@@ -176,7 +176,7 @@ class ShapeProp(torch.fx.Interpreter):
         except Exception as e:
             traceback.print_exc()
             raise RuntimeError(
-                f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}"
+                f"ShapeProp error for: node={n.format_node()} with meta={n.meta}"
             ) from e

         found_tensor = False
@@ -200,7 +200,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
             )
         else:
             raise RuntimeError(
-                "0D or 1D target tensor expected, " "multi-target not supported"
+                "0D or 1D target tensor expected, multi-target not supported"
             )

         is_batched = targ_dim > 0
|
@ -706,13 +706,13 @@ class Module:
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no " "attribute `" + item + "`"
|
||||
mod._get_name() + " has no attribute `" + item + "`"
|
||||
)
|
||||
|
||||
mod = getattr(mod, item)
|
||||
|
||||
if not isinstance(mod, torch.nn.Module):
|
||||
raise AttributeError("`" + item + "` is not " "an nn.Module")
|
||||
raise AttributeError("`" + item + "` is not an nn.Module")
|
||||
|
||||
return mod
|
||||
|
||||
@@ -829,7 +829,7 @@ class Module:
         param: torch.nn.Parameter = getattr(mod, param_name)

         if not isinstance(param, torch.nn.Parameter):
-            raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")
+            raise AttributeError("`" + param_name + "` is not an nn.Parameter")

         return param
@@ -313,7 +313,7 @@ class GroupNorm(Module):
         return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)

     def extra_repr(self) -> str:
-        return "{num_groups}, {num_channels}, eps={eps}, " "affine={affine}".format(
+        return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format(
             **self.__dict__
         )
@@ -42,7 +42,7 @@ def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
     for symbolic_func in default_and_custom_functions:
         overload_func = symbolic_func.onnx_function
         all_function_overload_names += f"ONNX Node: {overload_func.name}[opset={overload_func.opset};is_custom={symbolic_func.is_custom}]. \n"  # noqa: B950
-    return f"FX Node: {node.target}. \n" f"{all_function_overload_names}"
+    return f"FX Node: {node.target}. \n{all_function_overload_names}"


 def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
@@ -244,7 +244,7 @@ class LBFGS(Optimizer):

         if len(self.param_groups) != 1:
             raise ValueError(
-                "LBFGS doesn't support per-parameter options " "(parameter groups)"
+                "LBFGS doesn't support per-parameter options (parameter groups)"
             )

         self._params = self.param_groups[0]["params"]
@@ -864,7 +864,7 @@ class Optimizer:

         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of " "parameter groups"
+                "loaded state dict has a different number of parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
@@ -562,7 +562,7 @@ class PackageImporter(Importer):
             else:
                 where = "``from list''"
             raise TypeError(
-                f"Item in {where} must be str, " f"not {type(x).__name__}"
+                f"Item in {where} must be str, not {type(x).__name__}"
             )
         elif x == "*":
             if not recursive and hasattr(module, "__all__"):
@@ -356,7 +356,7 @@ class DataLoader(Generic[_T_co]):
             self._dataset_kind = _DatasetKind.Map

         if sampler is not None and shuffle:
-            raise ValueError("sampler option is mutually exclusive with " "shuffle")
+            raise ValueError("sampler option is mutually exclusive with shuffle")

         if batch_sampler is not None:
             # auto_collation with custom batch_sampler
@@ -680,7 +680,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
     def __init__(self, *datapipes: IterDataPipe):
         if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
             raise TypeError(
-                "All inputs are required to be `IterDataPipe` " "for `ZipIterDataPipe`."
+                "All inputs are required to be `IterDataPipe` for `ZipIterDataPipe`."
             )
         super().__init__()
         self.datapipes = datapipes  # type: ignore[assignment]