[chat] clean code and add base help (#37892)

This commit is contained in:
Joao Gante
2025-05-01 15:12:18 +01:00
committed by GitHub
parent 5b573bebb9
commit 410aa01901


@@ -44,26 +44,25 @@ if is_torch_available():
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set(
string.ascii_letters + string.digits + string.whitespace + r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~"
)
HELP_STRING = """\
**TRANSFORMERS CHAT INTERFACE**
The chat interface is a simple tool to try out a chat model.
Besides talking to the model there are several commands:
- **help**: show this help message
- **clear**: clears the current conversation and starts a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **exit**: closes the interface
"""
DEFAULT_EXAMPLES = {
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
"code": {
"text": (
"Write a Python function that integrates any Python function f(x) numerically over an arbitrary "
"interval [x_start, x_end]."
),
},
"helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
"birds": {"text": "Why aren't birds real?"},
"socks": {"text": "Why is it important to eat socks after meditating?"},
}
SUPPORTED_GENERATION_KWARGS = [
"max_new_tokens",
@@ -75,159 +74,39 @@ SUPPORTED_GENERATION_KWARGS = [
"repetition_penalty",
]
DEFAULT_EXAMPLES = {
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
"code": {
"text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]."
},
"helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
"birds": {"text": "Why aren't birds real?"},
"socks": {"text": "Why is it important to eat socks after meditating?"},
}
# Printed at the start of a chat session
HELP_STRING_MINIMAL = """
**TRANSFORMERS CHAT INTERFACE**
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **help**: shows all available commands
- **clear**: clears the current conversation and starts a new one
- **exit**: closes the interface
"""
def get_username():
if platform.system() == "Windows":
return os.getlogin()
else:
return pwd.getpwuid(os.getuid()).pw_name
# Printed when the user types `help` in the chat session
HELP_STRING = f"""
**TRANSFORMERS CHAT INTERFACE HELP**
def create_default_filename(model_name):
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
return f"{model_name}/chat_{time_str}.json"
def save_chat(chat, args, filename):
output_dict = {}
output_dict["settings"] = vars(args)
output_dict["chat_history"] = chat
folder = args.save_folder
if filename is None:
filename = create_default_filename(args.model_name_or_path)
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(output_dict, f, indent=4)
return os.path.abspath(filename)
def clear_chat_history(system_prompt):
if system_prompt is None:
chat = []
else:
chat = [{"role": "system", "content": system_prompt}]
return chat
def parse_settings(user_input, current_args, interface):
settings = user_input[4:].strip().split(";")
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
except ValueError:
interface.print_red(
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
)
else:
interface.print_red(f"There is no '{name}' setting.")
if error:
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
return current_args, False
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_green(f"Set {name} to {settings[name]}.")
time.sleep(1.5) # so the user has time to read the changes
return current_args, True
def get_quantization_config(model_args) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.torch_dtype,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def load_model_and_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = get_quantization_config(args)
model_kwargs = {
"revision": args.model_revision,
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
"quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
if tokenizer.pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
else:
pad_token_id = tokenizer.pad_token_id
all_eos_token_ids = []
if eos_tokens is not None:
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
if eos_token_ids is not None:
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
if len(all_eos_token_ids) == 0:
all_eos_token_ids.append(tokenizer.eos_token_id)
return pad_token_id, all_eos_token_ids
Full command list:
- **help**: shows this help message
- **clear**: clears the current conversation and starts a new one
- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
- **exit**: closes the interface
"""
class RichInterface:
def __init__(self, model_name=None, user_name=None):
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
self._console = Console()
if model_name is None:
self.model_name = "assistant"
@@ -238,9 +117,10 @@ class RichInterface:
else:
self.user_name = user_name
def stream_output(self, output_stream):
"""Stream output from a role."""
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
def stream_output(self, output_stream: TextIteratorStreamer) -> str:
"""Stream output from a role, and return the generated text after it's done steaming."""
# This method is originally from the FastChat CLI:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
self._console.print(f"[bold blue]<{self.model_name}>:")
@@ -276,93 +156,54 @@ class RichInterface:
self._console.print()
return text
def input(self):
def input(self) -> str:
"""Gets user input from the console."""
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
self._console.print()
return input
def clear(self):
"""Clears the console."""
self._console.clear()
def print_user_message(self, text):
def print_user_message(self, text: str):
"""Prints a user message to the console."""
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
self._console.print()
def print_green(self, text):
self._console.print(f"[bold green]{text}")
def print_color(self, text: str, color: str):
"""Prints text in a given color to the console."""
self._console.print(f"[bold {color}]{text}")
self._console.print()
def print_red(self, text):
self._console.print(f"[bold red]{text}")
self._console.print()
def print_help(self):
self._console.print(Markdown(HELP_STRING))
def print_help(self, minimal: bool = False):
"""Prints the help message to the console."""
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
self._console.print()
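# Illustrative sketch (not part of this diff): a tiny manual smoke test of the interface methods
# above. The function name is hypothetical and assumes `rich` is installed.
def _demo_rich_interface():
    interface = RichInterface(model_name="demo-model", user_name="demo-user")
    interface.print_help(minimal=True)                    # short banner printed at session start
    interface.print_user_message("Hello!")                # echo a user turn
    interface.print_color(text="Saved.", color="green")   # status message rendered in green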
@dataclass
class ChatArguments:
r"""
Arguments for the chat script.
Arguments for the chat CLI.
Args:
model_name_or_path (`str`):
Name of the pre-trained model.
user (`str` or `None`, *optional*, defaults to `None`):
Username to display in chat interface.
system_prompt (`str` or `None`, *optional*, defaults to `None`):
System prompt.
save_folder (`str`, *optional*, defaults to `"./chat_history/"`):
Folder to save chat history.
device (`str`, *optional*, defaults to `"cpu"`):
Device to use for inference.
examples_path (`str` or `None`, *optional*, defaults to `None`):
Path to a yaml file with examples.
max_new_tokens (`int`, *optional*, defaults to `256`):
Maximum number of tokens to generate.
do_sample (`bool`, *optional*, defaults to `True`):
Whether to sample outputs during generation.
num_beams (`int`, *optional*, defaults to `1`):
Number of beams for beam search.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature parameter for generation.
top_k (`int`, *optional*, defaults to `50`):
Value of k for top-k sampling.
top_p (`float`, *optional*, defaults to `1.0`):
Value of p for nucleus sampling.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Repetition penalty.
eos_tokens (`str` or `None`, *optional*, defaults to `None`):
EOS tokens to stop the generation. If multiple they should be comma separated.
eos_token_ids (`str` or `None`, *optional*, defaults to `None`):
EOS token IDs to stop the generation. If multiple they should be comma separated.
model_revision (`str`, *optional*, defaults to `"main"`):
Specific model version to use (can be a branch name, tag name or commit id).
torch_dtype (`str` or `None`, *optional*, defaults to `None`):
Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype
will be automatically derived from the model's weights.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether to trust remote code when loading a model.
attn_implementation (`str` or `None`, *optional*, defaults to `None`):
Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
you must install this manually by running `pip install flash-attn --no-build-isolation`.
load_in_8bit (`bool`, *optional*, defaults to `False`):
Whether to use 8 bit precision for the base model - works only with LoRA.
load_in_4bit (`bool`, *optional*, defaults to `False`):
Whether to use 4 bit precision for the base model - works only with LoRA.
bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
Quantization type.
use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
Whether to use nested quantization.
See the metadata arg for each argument's description -- the metadata will be printed with
`transformers chat --help`
"""
# General settings
model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Name of the pre-trained model."})
user: Optional[str] = field(default=None, metadata={"help": "Username to display in chat interface."})
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Name of the pre-trained model. The positional argument will take precedence if both are passed."
},
)
user: Optional[str] = field(
default=None,
metadata={"help": "Username to display in chat interface. Defaults to the current user's name."},
)
system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."})
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."})
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
# Generation settings
@@ -387,6 +228,7 @@ class ChatArguments:
default="main",
metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
)
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
torch_dtype: Optional[str] = field(
default="auto",
metadata={
@@ -434,7 +276,7 @@ class ChatCommand(BaseTransformersCLICommand):
parser: Root parser to register command-specific arguments
"""
dataclass_types = (ChatArguments,)
chat_parser = parser.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types)
chat_parser = parser.add_parser("chat", dataclass_types=dataclass_types)
group = chat_parser.add_argument_group("Positional arguments")
group.add_argument(
@@ -447,10 +289,123 @@ class ChatCommand(BaseTransformersCLICommand):
args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
if args.model_name_or_path is None:
raise ValueError("--model_name_or_path required for chat command.")
raise ValueError(
"One of the following must be provided:"
"\n- The positional argument containing the model repo;"
"\n- the optional --model_name_or_path argument, containing the model repo"
"\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
)
self.args = args
# -----------------------------------------------------------------------------------------------------------------
# Chat session methods
@staticmethod
def get_username() -> str:
"""Returns the username of the current user."""
if platform.system() == "Windows":
return os.getlogin()
else:
return pwd.getpwuid(os.getuid()).pw_name
@staticmethod
def save_chat(chat, args: ChatArguments, filename: Optional[str] = None) -> str:
"""Saves the chat history to a file."""
output_dict = {}
output_dict["settings"] = vars(args)
output_dict["chat_history"] = chat
folder = args.save_folder
if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{args.model_name_or_path}/chat_{time_str}.json"
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(output_dict, f, indent=4)
return os.path.abspath(filename)
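# Illustrative sketch (not part of this diff): the JSON written by `save_chat` above has roughly
# this shape; the concrete values are made up for demonstration.
#
#   {
#       "settings": {"model_name_or_path": "some-org/some-model", "temperature": 1.0, ...},
#       "chat_history": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"},
#           {"role": "assistant", "content": "Hi! How can I help?"}
#       ]
#   }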
@staticmethod
def clear_chat_history(system_prompt: Optional[str] = None) -> list[dict]:
"""Clears the chat history."""
if system_prompt is None:
chat = []
else:
chat = [{"role": "system", "content": system_prompt}]
return chat
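# Illustrative sketch (not part of this diff): clear_chat_history(None) returns [], while
# clear_chat_history("You are terse.") returns [{"role": "system", "content": "You are terse."}].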
# -----------------------------------------------------------------------------------------------------------------
# Input parsing methods
@staticmethod
def parse_settings(
user_input: str, current_args: ChatArguments, interface: RichInterface
) -> tuple[ChatArguments, bool]:
"""Parses the settings from the user input into the CLI arguments."""
settings = user_input[4:].strip().split(";")
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
except ValueError:
error = True
interface.print_color(
text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
color="red",
)
else:
interface.print_color(text=f"There is no '{name}' setting.", color="red")
if error:
interface.print_color(
text="There was an issue parsing the settings. No settings have been changed.",
color="red",
)
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
time.sleep(1.5) # so the user has time to read the changes
return current_args, not error
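# Illustrative sketch (not part of this diff): for the input "set temperature=0.7;max_new_tokens=64"
# the parsing above strips the leading "set ", splits on ";", then splits each chunk on the first
# "=" only, yielding {"temperature": "0.7", "max_new_tokens": "64"} before each value is cast to
# the type of the matching ChatArguments field (float and int here).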
@staticmethod
def parse_eos_tokens(
tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
) -> tuple[int, list[int]]:
"""Retrieves the pad token ID and all possible EOS token IDs."""
if tokenizer.pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
else:
pad_token_id = tokenizer.pad_token_id
all_eos_token_ids = []
if eos_tokens is not None:
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
if eos_token_ids is not None:
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
if len(all_eos_token_ids) == 0:
all_eos_token_ids.append(tokenizer.eos_token_id)
return pad_token_id, all_eos_token_ids
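# Illustrative sketch (not part of this diff): with a hypothetical --eos_tokens "</s>,<|eot|>"
# the comma-separated tokens are mapped to IDs via convert_tokens_to_ids, and with
# --eos_token_ids "2,32000" the IDs are used directly; if neither is passed, the tokenizer's own
# eos_token_id is the only stop token.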
@staticmethod
def is_valid_setting_command(s: str) -> bool:
# First check the basic structure
@@ -481,6 +436,55 @@ class ChatCommand(BaseTransformersCLICommand):
return True
# -----------------------------------------------------------------------------------------------------------------
# Model loading and performance automation methods
@staticmethod
def get_quantization_config(model_args: ChatArguments) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
# For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_compute_dtype=model_args.torch_dtype,
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.torch_dtype,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
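# Illustrative note (not part of this diff): --load_in_4bit / --load_in_8bit require the
# `bitsandbytes` package (typically with a CUDA GPU); when both are unset no BitsAndBytesConfig
# is created and the model simply loads in the torch_dtype resolved below.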
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = self.get_quantization_config(args)
model_kwargs = {
"revision": args.model_revision,
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
"quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
return model, tokenizer
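# Illustrative sketch (not part of this diff): how the TextIteratorStreamer created in `run` is
# typically consumed -- generation runs in a background thread while the streamer is iterated so
# text appears as it is produced. Names below are for demonstration only.
#
#   from threading import Thread
#   inputs = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)
#   thread = Thread(target=model.generate, kwargs={"inputs": inputs, "streamer": generation_streamer, "max_new_tokens": 256})
#   thread.start()
#   reply = interface.stream_output(generation_streamer)  # RichInterface.stream_output iterates the streamer
#   thread.join()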
# -----------------------------------------------------------------------------------------------------------------
# Main logic
def run(self):
if not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
@@ -497,24 +501,27 @@ class ChatCommand(BaseTransformersCLICommand):
current_args = copy.deepcopy(args)
if args.user is None:
user = get_username()
user = self.get_username()
else:
user = args.user
model, tokenizer = load_model_and_tokenizer(args)
model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = clear_chat_history(current_args.system_prompt)
chat = self.clear_chat_history(current_args.system_prompt)
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True)
while True:
try:
user_input = interface.input()
if user_input == "clear":
chat = clear_chat_history(current_args.system_prompt)
chat = self.clear_chat_history(current_args.system_prompt)
interface.clear()
continue
@@ -528,7 +535,7 @@ class ChatCommand(BaseTransformersCLICommand):
if user_input == "reset":
interface.clear()
current_args = copy.deepcopy(args)
chat = clear_chat_history(current_args.system_prompt)
chat = self.clear_chat_history(current_args.system_prompt)
continue
if user_input.startswith("save") and len(user_input.split()) < 2:
@@ -538,12 +545,12 @@ class ChatCommand(BaseTransformersCLICommand):
filename = split_input[1]
else:
filename = None
filename = save_chat(chat, current_args, filename)
interface.print_green(f"Chat saved in {filename}!")
filename = self.save_chat(chat, current_args, filename)
interface.print_color(text=f"Chat saved in {filename}!", color="green")
continue
if self.is_valid_setting_command(user_input):
current_args, success = parse_settings(user_input, current_args, interface)
current_args, success = self.parse_settings(user_input, current_args, interface)
if success:
chat = []
interface.clear()
@@ -557,9 +564,10 @@ class ChatCommand(BaseTransformersCLICommand):
interface.print_user_message(examples[example_name]["text"])
user_input = examples[example_name]["text"]
else:
interface.print_red(
example_error = (
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
)
interface.print_color(text=example_error, color="red")
continue
chat.append({"role": "user", "content": user_input})