Mirror of https://github.com/huggingface/transformers.git
Enable SIM rules (#39806)
* Enable SIM rules

Signed-off-by: cyy <cyyever@outlook.com>

* More fixes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
@@ -19,10 +19,13 @@ line-length = 119

 [tool.ruff.lint]
 # Never enforce `E501` (line length violations).
-ignore = ["C901", "E501", "E741", "F402", "F823"]
+# SIM300: Yoda condition detected
+# SIM212: Checks for if expressions that check against a negated condition.
+# SIM905: Consider using a list literal instead of `str.split`
+ignore = ["C901", "E501", "E741", "F402", "F823", "SIM1", "SIM300", "SIM212", "SIM905"]
 # RUF013: Checks for the use of implicit Optional
 # in type annotations when the default parameter value is None.
-select = ["C", "E", "F", "I", "W", "RUF013", "UP006", "PERF102", "PLC1802", "PLC0208"]
+select = ["C", "E", "F", "I", "W", "RUF013", "UP006", "PERF102", "PLC1802", "PLC0208","SIM"]
 extend-safe-fixes = ["UP006"]

 # Ignore import violations in all `__init__.py` files.
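For context, Ruff's SIM ("flake8-simplify") category flags mechanical rewrites such as dropping redundant `None` defaults from `dict.get()` and merging duplicate `isinstance()` checks, which is what the remaining hunks in this commit apply. A minimal, hypothetical sketch of code the new configuration would flag (the function and variable names below are illustrative, not from the repository; the "before" forms appear as comments):

    def describe(payload: dict, value) -> str:
        # label = payload.get("label", None)   # SIM910: redundant None default
        label = payload.get("label")
        # if isinstance(value, int) or isinstance(value, float):   # SIM101: duplicate isinstance calls
        if isinstance(value, (int, float)):
            return f"{label}: numeric"
        return f"{label}: {value!r}"

    print(describe({"label": "x"}, 3))  # -> x: numeric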
@@ -901,7 +901,7 @@ class ServeCommand(BaseTransformersCLICommand):
 inputs = processor.apply_chat_template(
     processor_inputs,
     add_generation_prompt=True,
-    tools=req.get("tools", None),
+    tools=req.get("tools"),
     return_tensors="pt",
     return_dict=True,
     tokenize=True,
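The change above is the `dict.get` default cleanup: `get()` already returns `None` when the key is missing, so spelling the default out is redundant. A small sketch with a made-up request payload:

    req = {"messages": []}            # hypothetical request dict, no "tools" key
    tools = req.get("tools", None)    # flagged form: explicit None default
    tools = req.get("tools")          # equivalent, simplified form
    assert tools is None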
@@ -183,7 +183,7 @@ def tf_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]:
 if label_col_name is not None:
     if isinstance(first[label_col_name], tf.Tensor):
         dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
-    elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
+    elif isinstance(first[label_col_name], (np.ndarray, np.generic)):
         dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
     elif isinstance(first[label_col_name], (tuple, list)):
         dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
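`isinstance()` accepts a tuple of types, so two checks on the same object joined with `or` collapse into one call, as in the hunk above. A runnable sketch (the values are illustrative):

    import numpy as np

    first_label = np.int64(3)
    # Before: isinstance(first_label, np.ndarray) or isinstance(first_label, np.generic)
    if isinstance(first_label, (np.ndarray, np.generic)):
        dtype = "int64" if np.issubdtype(first_label.dtype, np.integer) else "float32"
        print(dtype)  # -> int64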
@@ -647,8 +647,8 @@ class GenerationMixin(ContinuousMixin):

 # If it's not defined, it means the model uses the new general mask API
 if causal_mask_creation_function is None: # can't be found
-    token_type_ids = model_inputs.get("token_type_ids", None)
-    position_ids = model_inputs.get(position_ids_key, None)
+    token_type_ids = model_inputs.get("token_type_ids")
+    position_ids = model_inputs.get(position_ids_key)
     # Some models may overwrite the general one
     causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
     attention_mask = causal_mask_creation_function(
@@ -284,7 +284,7 @@ def flex_attention_forward(
 num_local_query_heads = query.shape[1]

 # When running TP this helps:
-if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+if (num_local_query_heads & (num_local_query_heads - 1)) != 0:
     key = repeat_kv(key, query.shape[1] // key.shape[1])
     value = repeat_kv(value, query.shape[1] // value.shape[1])
     enable_gqa = False
@@ -45,7 +45,7 @@ def replace_with_vptq_linear(
     should not be passed by the user.
 """

-modules_to_not_convert = ["lm_head"] if not modules_to_not_convert else modules_to_not_convert
+modules_to_not_convert = modules_to_not_convert if modules_to_not_convert else ["lm_head"]

 for name, module in model.named_children():
     if current_key_name is None:
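The hunk above flips a ternary that tested a negated condition; flake8-simplify prefers testing the value directly and keeping the fallback second. A sketch with a hypothetical caller input:

    modules_to_not_convert = None   # hypothetical caller input
    # Before: ["lm_head"] if not modules_to_not_convert else modules_to_not_convert
    modules = modules_to_not_convert if modules_to_not_convert else ["lm_head"]
    assert modules == ["lm_head"]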
@@ -167,7 +167,7 @@ class KerasMetricCallback(keras.callbacks.Callback):
 # If it's a dict with only one key, just return the array
 if len(outputs) == 1:
     outputs = list(outputs.values())[0]
-elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
+elif isinstance(inputs[0], (tuple, list)):
     outputs = []
     for input_list in zip(*inputs):
         outputs.append(self._concatenate_batches(input_list))
@@ -81,9 +81,7 @@ class AyaVisionConfig(PretrainedConfig):
 self.vision_feature_layer = vision_feature_layer

 if isinstance(vision_config, dict):
-    vision_config["model_type"] = (
-        vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
-    )
+    vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
     vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
 elif vision_config is None:
     vision_config = CONFIG_MAPPING["siglip_vision_model"](
@@ -99,7 +97,7 @@ class AyaVisionConfig(PretrainedConfig):
 self.vision_config = vision_config

 if isinstance(text_config, dict):
-    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "cohere2"
+    text_config["model_type"] = text_config.get("model_type", "cohere2")
     text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
 elif text_config is None:
     text_config = CONFIG_MAPPING["cohere2"]()
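Both hunks above replace the `value if key in dict else default` pattern with a single `dict.get(key, default)` call, which is exactly equivalent here. Roughly:

    vision_config = {"hidden_size": 1152}   # hypothetical config dict without "model_type"
    # Before: vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
    model_type = vision_config.get("model_type", "siglip_vision_model")
    assert model_type == "siglip_vision_model"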
@@ -148,7 +148,7 @@ def rewrite_dict_keys(d):
 # (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
 # e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
 d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
-keep_keys = "<s> <pad> </s> <unk>".split()
+keep_keys = ["<s>", "<pad>", "</s>", "<unk>"]
 # restore the special tokens
 for k in keep_keys:
     del d2[f"{k}</w>"]
@@ -1303,7 +1303,7 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):

 # Check if conditioning_embeds are provided or not, if yes then concatenate the bos_token_id at the end of the conditioning_embeds.
 # Then we must subtract the positional_ids because during the forward pass it will be added anyways, so we must cancel them out here.
-conditioning_embeds = model_kwargs.get("conditioning_embeds", None)
+conditioning_embeds = model_kwargs.get("conditioning_embeds")

 if conditioning_embeds is not None:
     mel_start_token_embedding = self.model.decoder.input_embeds_layer(
@@ -57,9 +57,7 @@ class Cohere2VisionConfig(PretrainedConfig):
 self.alignment_intermediate_size = alignment_intermediate_size

 if isinstance(vision_config, dict):
-    vision_config["model_type"] = (
-        vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
-    )
+    vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
     vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
 elif vision_config is None:
     vision_config = CONFIG_MAPPING["siglip_vision_model"](
@@ -73,7 +71,7 @@ class Cohere2VisionConfig(PretrainedConfig):
 self.vision_config = vision_config

 if isinstance(text_config, dict):
-    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "cohere2"
+    text_config["model_type"] = text_config.get("model_type", "cohere2")
     text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
 elif text_config is None:
     text_config = CONFIG_MAPPING["cohere2"](tie_word_embeddings=True)
@@ -66,7 +66,7 @@ class DeepseekVLImageProcessorFast(BaseImageProcessorFast):

 def __init__(self, **kwargs: Unpack[DeepseekVLFastImageProcessorKwargs]):
     super().__init__(**kwargs)
-    if kwargs.get("image_mean", None) is None:
+    if kwargs.get("image_mean") is None:
         background_color = (127, 127, 127)
     else:
         background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
@@ -93,11 +93,11 @@ class DeepseekVLHybridImageProcessorFast(BaseImageProcessorFast):
 high_res_resample = PILImageResampling.BICUBIC

 def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
-    if kwargs.get("image_mean", None) is None:
+    if kwargs.get("image_mean") is None:
         background_color = (127, 127, 127)
     else:
         background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
-    if kwargs.get("high_res_image_mean", None) is None:
+    if kwargs.get("high_res_image_mean") is None:
         high_res_background_color = (127, 127, 127)
     else:
         high_res_background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
@@ -749,11 +749,11 @@ class DeepseekVLHybridImageProcessorFast(DeepseekVLImageProcessorFast):
 high_res_resample = PILImageResampling.BICUBIC

 def __init__(self, **kwargs: Unpack[DeepseekVLHybridFastImageProcessorKwargs]):
-    if kwargs.get("image_mean", None) is None:
+    if kwargs.get("image_mean") is None:
         background_color = (127, 127, 127)
     else:
         background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
-    if kwargs.get("high_res_image_mean", None) is None:
+    if kwargs.get("high_res_image_mean") is None:
         high_res_background_color = (127, 127, 127)
     else:
         high_res_background_color = tuple([int(x * 255) for x in kwargs.get("high_res_image_mean")])
@@ -79,7 +79,7 @@ def rewrite_dict_keys(d):
 # (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
 # e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
 d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
-keep_keys = "<s> <pad> </s> <unk>".split()
+keep_keys = ["<s>", "<pad>", "</s>", "<unk>"]
 # restore the special tokens
 for k in keep_keys:
     del d2[f"{k}</w>"]
@@ -71,7 +71,7 @@ class JanusImageProcessorFast(BaseImageProcessorFast):
 valid_kwargs = JanusFastImageProcessorKwargs

 def __init__(self, **kwargs: Unpack[JanusFastImageProcessorKwargs]):
-    if kwargs.get("image_mean", None) is None:
+    if kwargs.get("image_mean") is None:
         background_color = (127, 127, 127)
     else:
         background_color = tuple([int(x * 255) for x in kwargs.get("image_mean")])
@@ -857,9 +857,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
     module.weight.data.fill_(1.0)
 elif isinstance(module, JetMoeParallelExperts):
     module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-elif isinstance(module, JetMoeMoA):
-    module.bias.data.zero_()
-elif isinstance(module, JetMoeMoE):
+elif isinstance(module, (JetMoeMoA, JetMoeMoE)):
     module.bias.data.zero_()

@@ -260,7 +260,7 @@ class MMGroundingDinoConfig(PretrainedConfig):
 self.disable_custom_kernels = disable_custom_kernels
 # Text backbone
 if isinstance(text_config, dict):
-    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "bert"
+    text_config["model_type"] = text_config.get("model_type", "bert")
     text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
 elif text_config is None:
     text_config = CONFIG_MAPPING["bert"]()
@@ -268,7 +268,7 @@ class MMGroundingDinoConfig(GroundingDinoConfig, PretrainedConfig):
 self.disable_custom_kernels = disable_custom_kernels
 # Text backbone
 if isinstance(text_config, dict):
-    text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "bert"
+    text_config["model_type"] = text_config.get("model_type", "bert")
     text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
 elif text_config is None:
     text_config = CONFIG_MAPPING["bert"]()
@@ -2794,11 +2794,7 @@ class OneFormerPreTrainedModel(PreTrainedModel):
     nn.init.constant_(module.output_proj.bias.data, 0.0)
 elif isinstance(module, OneFormerPixelDecoder):
     nn.init.normal_(module.level_embed, std=0)
-elif isinstance(module, OneFormerTransformerDecoderLayer):
-    for p in module.parameters():
-        if p.dim() > 1:
-            nn.init.xavier_uniform_(p, gain=xavier_std)
-elif isinstance(module, OneFormerTransformerDecoderQueryTransformer):
+elif isinstance(module, (OneFormerTransformerDecoderLayer, OneFormerTransformerDecoderQueryTransformer)):
     for p in module.parameters():
         if p.dim() > 1:
             nn.init.xavier_uniform_(p, gain=xavier_std)
@@ -197,7 +197,7 @@ class FillMaskPipeline(Pipeline):
 vocab = {}
 target_ids = []
 for target in targets:
-    id_ = vocab.get(target, None)
+    id_ = vocab.get(target)
     if id_ is None:
         input_ids = self.tokenizer(
             target,
@@ -427,9 +427,11 @@ class TokenClassificationPipeline(ChunkPipeline):
 if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
     current_length = entity["end"] - entity["start"]
     previous_length = previous_entity["end"] - previous_entity["start"]
-    if current_length > previous_length:
-        previous_entity = entity
-    elif current_length == previous_length and entity["score"] > previous_entity["score"]:
+    if (
+        current_length > previous_length
+        or current_length == previous_length
+        and entity["score"] > previous_entity["score"]
+    ):
         previous_entity = entity
 else:
     aggregated_entities.append(previous_entity)
@@ -2637,9 +2637,7 @@ def nested_simplify(obj, decimals=3):
     return nested_simplify(obj.tolist())
 elif isinstance(obj, Mapping):
     return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
-elif isinstance(obj, (str, int, np.int64)):
-    return obj
-elif obj is None:
+elif isinstance(obj, (str, int, np.int64)) or obj is None:
     return obj
 elif is_torch_available() and isinstance(obj, torch.Tensor):
     return nested_simplify(obj.tolist(), decimals)
@@ -1784,9 +1784,7 @@ class MistralCommonTokenizer(PushToHubMixin):
 pathlib_repo_file = Path(path)
 file_name = pathlib_repo_file.name
 suffix = "".join(pathlib_repo_file.suffixes)
-if file_name == "tekken.json":
-    valid_tokenizer_files.append(file_name)
-elif suffix in sentencepiece_suffixes:
+if file_name == "tekken.json" or suffix in sentencepiece_suffixes:
     valid_tokenizer_files.append(file_name)

 if len(valid_tokenizer_files) == 0:
@@ -877,9 +877,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):

 input_ids = []
 for ids_or_pair_ids in batch_text_or_text_pairs:
-    if not isinstance(ids_or_pair_ids, (list, tuple)):
-        ids, pair_ids = ids_or_pair_ids, None
-    elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
+    if (
+        not isinstance(ids_or_pair_ids, (list, tuple))
+        or is_split_into_words
+        and not isinstance(ids_or_pair_ids[0], (list, tuple))
+    ):
         ids, pair_ids = ids_or_pair_ids, None
     else:
         ids, pair_ids = ids_or_pair_ids
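Here an `if` and an `elif` with identical bodies are folded into one condition; since `and` binds tighter than `or`, the merged expression keeps the original meaning. A self-contained sketch (the helper name is made up):

    def split_pair(ids_or_pair_ids, is_split_into_words: bool):
        # The original `if`/`elif` pair had identical bodies; one condition now covers both.
        if not isinstance(ids_or_pair_ids, (list, tuple)) or (
            is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple))
        ):
            return ids_or_pair_ids, None
        return ids_or_pair_ids[0], ids_or_pair_ids[1]

    print(split_pair("hello world", is_split_into_words=False))  # -> ('hello world', None)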
@@ -153,9 +153,7 @@ def find_batch_size(tensors):
     result = find_batch_size(value)
     if result is not None:
         return result
-elif isinstance(tensors, torch.Tensor):
-    return tensors.shape[0] if len(tensors.shape) >= 1 else None
-elif isinstance(tensors, np.ndarray):
+elif isinstance(tensors, (torch.Tensor, np.ndarray)):
     return tensors.shape[0] if len(tensors.shape) >= 1 else None

@@ -634,10 +632,7 @@ class LengthGroupedSampler(Sampler):
 self.batch_size = batch_size
 if lengths is None:
     model_input_name = model_input_name if model_input_name is not None else "input_ids"
-    if (
-        not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
-        or model_input_name not in dataset[0]
-    ):
+    if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
         raise ValueError(
             "Can only automatically infer lengths for datasets whose items are dictionaries with an "
             f"'{model_input_name}' key."
@@ -697,10 +692,7 @@ class DistributedLengthGroupedSampler(DistributedSampler):

 if lengths is None:
     model_input_name = model_input_name if model_input_name is not None else "input_ids"
-    if (
-        not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
-        or model_input_name not in dataset[0]
-    ):
+    if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
         raise ValueError(
             "Can only automatically infer lengths for datasets whose items are dictionaries with an "
             f"'{model_input_name}' key."
@@ -1404,8 +1404,8 @@ def _process_regular_parameters(
         param_type = f"[`{class_name}`]"
     else:
         param_type = f"[`{param_type.split('.')[-1]}`]"
-elif param_type == "" and False:  # TODO: Enforce typing for all parameters
-    print(f"🚨 {param_name} for {func.__qualname__} in file {func.__code__.co_filename} has no type")
+# elif param_type == "" and False:  # TODO: Enforce typing for all parameters
+#     print(f"🚨 {param_name} for {func.__qualname__} in file {func.__code__.co_filename} has no type")
 param_type = param_type if "`" in param_type else f"`{param_type}`"
 # Format the parameter docstring
 if additional_info:
@@ -1833,7 +1833,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
     docstring += set_min_indent(f"\n{docstring_init}", indent_level)
 elif is_dataclass:
     # No init function, we have a data class
-    docstring += "\nArgs:\n" if not docstring_args else docstring_args
+    docstring += docstring_args if docstring_args else "\nArgs:\n"
     source_args_dict = get_args_doc_from_source(ModelOutputArgs)
     doc_class = cls.__doc__ if cls.__doc__ else ""
     documented_kwargs, _ = parse_docstring(doc_class)
@@ -749,9 +749,7 @@ def create_wrapper(
 tracer = found_proxies[0].tracer
 if op_type == "call_function":
     target = function
-elif op_type == "call_method":
-    target = function.__name__
-elif op_type == "get_attr":
+elif op_type == "call_method" or op_type == "get_attr":
     target = function.__name__
 else:
     raise ValueError(f"op_type {op_type} not supported.")
@@ -451,9 +451,7 @@ def get_torch_major_and_minor_version() -> str:


 def is_torch_sdpa_available():
-    if not is_torch_available():
-        return False
-    elif _torch_version == "N/A":
+    if not is_torch_available() or _torch_version == "N/A":
         return False

     # NOTE: MLU is OK with non-contiguous inputs.
@@ -467,9 +465,7 @@ def is_torch_sdpa_available():


 def is_torch_flex_attn_available():
-    if not is_torch_available():
-        return False
-    elif _torch_version == "N/A":
+    if not is_torch_available() or _torch_version == "N/A":
         return False

     # TODO check if some bugs cause push backs on the exact version
@@ -1088,7 +1084,7 @@ def is_ninja_available():
     [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
     """
     try:
-        subprocess.check_output("ninja --version".split())
+        subprocess.check_output(["ninja", "--version"])
     except Exception:
         return False
     else:
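Passing the argv list directly to `subprocess` avoids splitting a string literal at call time, which is the list-literal rule again. A sketch of an availability check in the same spirit (not the repository's implementation):

    import subprocess

    def ninja_available() -> bool:
        # Before: subprocess.check_output("ninja --version".split())
        try:
            subprocess.check_output(["ninja", "--version"])
        except Exception:
            return False
        return True

    print(ninja_available())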
@@ -1222,7 +1222,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):

 # this is just inference, so no optimizer should be loaded
 # it only works for z3 (makes no sense with z1-z2)
-fp32 = True if dtype == "fp32" else False
+fp32 = dtype == "fp32"
 self.run_and_check(
     stage=ZERO3,
     dtype=dtype,
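A comparison already evaluates to a `bool`, so wrapping it in `True if ... else False` is redundant; the direct expression above is equivalent. For example:

    dtype = "fp32"                               # hypothetical test parameter
    fp32 = True if dtype == "fp32" else False    # flagged: redundant conditional
    fp32 = dtype == "fp32"                       # equivalent boolean expression
    assert fp32 is True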
@@ -1337,13 +1337,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):

 if do_eval:
     actions += 1
-    args.extend(
-        """
-        --do_eval
-        --max_eval_samples 16
-        --per_device_eval_batch_size 2
-        """.split()
-    )
+    args.extend(["--do_eval", "--max_eval_samples", "16", "--per_device_eval_batch_size", "2"])

 assert actions > 0, "need at least do_train or do_eval for the test to run"

@@ -315,9 +315,7 @@ class TestTrainerExt(TestCasePlus):
     --eval_steps {str(eval_steps)}
 """.split()

-args_predict = """
-    --do_predict
-""".split()
+args_predict = ["--do_predict"]

 args = []
 if do_train:
@@ -330,11 +328,11 @@ class TestTrainerExt(TestCasePlus):
     args += args_predict

 if predict_with_generate:
-    args += "--predict_with_generate".split()
+    args += ["--predict_with_generate"]

 if do_train:
     if optim == "adafactor":
-        args += "--adafactor".split()
+        args += ["--adafactor"]
     else:
         args += f"--optim {optim}".split()

@@ -122,7 +122,7 @@ class TestFSDPGeneration(TestCasePlus):
     --master_port={get_torch_dist_unique_port()}
     {self.test_file_dir}/test_fsdp.py
 """.split()
-args = "--fsdp".split()
+args = ["--fsdp"]
 cmd = ["torchrun"] + distributed_args + args
 execute_subprocess_async(cmd, env=self.get_env())
 # successful return here == success - any errors would have caused an error in the sub-call
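Splitting a constant string only to produce a fixed list is the `str.split` pattern flagged throughout these test changes; the list literal gives the same value with no call. For instance:

    # "--fsdp".split() evaluates to ["--fsdp"]; the literal says so directly.
    args = ["--fsdp"]
    distributed_args = ["--nproc_per_node=2"]    # hypothetical torchrun arguments
    cmd = ["torchrun"] + distributed_args + args
    print(cmd)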
@@ -135,7 +135,7 @@ class TestFSDPGeneration(TestCasePlus):
     --master_port={get_torch_dist_unique_port()}
     {self.test_file_dir}/test_fsdp.py
 """.split()
-args = "--fsdp2".split()
+args = ["--fsdp2"]
 cmd = ["torchrun"] + distributed_args + args
 execute_subprocess_async(cmd, env=self.get_env())
 # successful return here == success - any errors would have caused an error in the sub-call
@@ -5017,7 +5017,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):

 def get_commit_history(self, repo):
     commit_logs = subprocess.run(
-        "git log".split(),
+        ["git", "log"],
         capture_output=True,
         check=True,
         encoding="utf-8",