[chat] clean code and add base help (#37892)
Mirror of https://github.com/huggingface/transformers.git
@@ -44,26 +44,25 @@ if is_torch_available():
     from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

 ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
 ALLOWED_VALUE_CHARS = set(
     string.ascii_letters + string.digits + string.whitespace + r".!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~"
 )

-HELP_STRING = """\
-
-**TRANSFORMERS CHAT INTERFACE**
-
-The chat interface is a simple tool to try out a chat model.
-
-Besides talking to the model there are several commands:
-- **help**: show this help message
-- **clear**: clears the current conversation and start a new one
-- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
-- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
-- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
-- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
-- **exit**: closes the interface
-"""
+DEFAULT_EXAMPLES = {
+    "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
+    "code": {
+        "text": (
+            "Write a Python function that integrates any Python function f(x) numerically over an arbitrary "
+            "interval [x_start, x_end]."
+        ),
+    },
+    "helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
+    "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
+    "birds": {"text": "Why aren't birds real?"},
+    "socks": {"text": "Why is it important to eat socks after meditating?"},
+}

 SUPPORTED_GENERATION_KWARGS = [
     "max_new_tokens",
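To illustrate how these entries get used: the session loop later resolves `example {NAME}` against this dict and substitutes the stored text for the user's turn. A minimal standalone sketch, trimmed to two entries:

DEFAULT_EXAMPLES = {
    "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
    "birds": {"text": "Why aren't birds real?"},
}

user_input = "example birds"
example_name = user_input.split()[1]
if example_name in DEFAULT_EXAMPLES:
    # The stored text replaces the raw command as the user message
    user_input = DEFAULT_EXAMPLES[example_name]["text"]
print(user_input)  # Why aren't birds real?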
@@ -75,159 +74,39 @@ SUPPORTED_GENERATION_KWARGS = [
     "repetition_penalty",
 ]

-DEFAULT_EXAMPLES = {
-    "llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
-    "code": {
-        "text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]."
-    },
-    "helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
-    "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
-    "birds": {"text": "Why aren't birds real?"},
-    "socks": {"text": "Why is it important to eat socks after meditating?"},
-}
-
-
-def get_username():
-    if platform.system() == "Windows":
-        return os.getlogin()
-    else:
-        return pwd.getpwuid(os.getuid()).pw_name
-
-
-def create_default_filename(model_name):
-    time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
-    return f"{model_name}/chat_{time_str}.json"
-
-
-def save_chat(chat, args, filename):
-    output_dict = {}
-    output_dict["settings"] = vars(args)
-    output_dict["chat_history"] = chat
-
-    folder = args.save_folder
-
-    if filename is None:
-        filename = create_default_filename(args.model_name_or_path)
-    filename = os.path.join(folder, filename)
-    os.makedirs(os.path.dirname(filename), exist_ok=True)
-
-    with open(filename, "w") as f:
-        json.dump(output_dict, f, indent=4)
-    return os.path.abspath(filename)
-
-
-def clear_chat_history(system_prompt):
-    if system_prompt is None:
-        chat = []
-    else:
-        chat = [{"role": "system", "content": system_prompt}]
-    return chat
-
-
-def parse_settings(user_input, current_args, interface):
-    settings = user_input[4:].strip().split(";")
-    settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
-    settings = dict(settings)
-    error = False
-
-    for name in settings:
-        if hasattr(current_args, name):
-            try:
-                if isinstance(getattr(current_args, name), bool):
-                    if settings[name] == "True":
-                        settings[name] = True
-                    elif settings[name] == "False":
-                        settings[name] = False
-                    else:
-                        raise ValueError
-                else:
-                    settings[name] = type(getattr(current_args, name))(settings[name])
-            except ValueError:
-                interface.print_red(
-                    f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
-                )
-        else:
-            interface.print_red(f"There is no '{name}' setting.")
-
-    if error:
-        interface.print_red("There was an issue parsing the settings. No settings have been changed.")
-        return current_args, False
-    else:
-        for name in settings:
-            setattr(current_args, name, settings[name])
-            interface.print_green(f"Set {name} to {settings[name]}.")
-
-    time.sleep(1.5)  # so the user has time to read the changes
-    return current_args, True
-
-
-def get_quantization_config(model_args) -> Optional["BitsAndBytesConfig"]:
-    if model_args.load_in_4bit:
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=model_args.torch_dtype,  # For consistency with model weights, we use the same value as `torch_dtype`
-            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
-            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
-            bnb_4bit_quant_storage=model_args.torch_dtype,
-        )
-    elif model_args.load_in_8bit:
-        quantization_config = BitsAndBytesConfig(
-            load_in_8bit=True,
-        )
-    else:
-        quantization_config = None
-
-    return quantization_config
-
-
-def load_model_and_tokenizer(args):
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.model_name_or_path,
-        revision=args.model_revision,
-        trust_remote_code=args.trust_remote_code,
-    )
-
-    torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
-    quantization_config = get_quantization_config(args)
-    model_kwargs = {
-        "revision": args.model_revision,
-        "attn_implementation": args.attn_implementation,
-        "torch_dtype": torch_dtype,
-        "device_map": "auto",
-        "quantization_config": quantization_config,
-    }
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
-    )
-
-    if getattr(model, "hf_device_map", None) is None:
-        model = model.to(args.device)
-
-    return model, tokenizer
-
-
-def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
-    if tokenizer.pad_token_id is None:
-        pad_token_id = tokenizer.eos_token_id
-    else:
-        pad_token_id = tokenizer.pad_token_id
-
-    all_eos_token_ids = []
-
-    if eos_tokens is not None:
-        all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
-
-    if eos_token_ids is not None:
-        all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
-
-    if len(all_eos_token_ids) == 0:
-        all_eos_token_ids.append(tokenizer.eos_token_id)
-
-    return pad_token_id, all_eos_token_ids
+# Printed at the start of a chat session
+HELP_STRING_MINIMAL = """
+
+**TRANSFORMERS CHAT INTERFACE**
+
+Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
+- **help**: shows all available commands
+- **clear**: clears the current conversation and starts a new one
+- **exit**: closes the interface
+"""
+
+# Printed when the user types `help` in the chat session
+HELP_STRING = f"""
+
+**TRANSFORMERS CHAT INTERFACE HELP**
+
+Full command list:
+- **help**: shows this help message
+- **clear**: clears the current conversation and starts a new one
+- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
+names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
+- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
+separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
+- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
+- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
+`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
+- **exit**: closes the interface
+"""


 class RichInterface:
-    def __init__(self, model_name=None, user_name=None):
+    def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
         self._console = Console()
         if model_name is None:
             self.model_name = "assistant"
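The new HELP_STRING relies on an f-string trick worth spelling out: the join separator itself carries the closing and opening backticks, so the rendered list comes out backtick-quoted. A quick standalone check:

DEFAULT_EXAMPLES = {"llama": {}, "code": {}, "helicopter": {}}
line = f"Available example names: `{'`, `'.join(DEFAULT_EXAMPLES.keys())}`"
print(line)  # Available example names: `llama`, `code`, `helicopter`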
@@ -238,9 +117,10 @@ class RichInterface:
         else:
             self.user_name = user_name

-    def stream_output(self, output_stream):
-        """Stream output from a role."""
-        # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
+    def stream_output(self, output_stream: TextIteratorStreamer) -> str:
+        """Stream output from a role, and return the generated text after it's done streaming."""
+        # This method is originally from the FastChat CLI:
+        # https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
         # Create a Live context for updating the console output
         text = ""
         self._console.print(f"[bold blue]<{self.model_name}>:")
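stream_output consumes a TextIteratorStreamer, which yields decoded text chunks while generation runs on another thread. A minimal sketch of that producer/consumer pattern (the model, tokenizer, and chat variables are assumed here, not part of this diff):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, chat):
    inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
    # generate() blocks, so it runs in a thread while we iterate the streamer
    thread = Thread(target=model.generate, kwargs={"inputs": inputs, "streamer": streamer, "max_new_tokens": 256})
    thread.start()
    text = ""
    for new_token_text in streamer:
        text += new_token_text
    thread.join()
    return text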
@@ -276,93 +156,54 @@ class RichInterface:
         self._console.print()
         return text

-    def input(self):
+    def input(self) -> str:
         """Gets user input from the console."""
         input = self._console.input(f"[bold red]<{self.user_name}>:\n")
         self._console.print()
         return input

     def clear(self):
         """Clears the console."""
         self._console.clear()

-    def print_user_message(self, text):
+    def print_user_message(self, text: str):
         """Prints a user message to the console."""
         self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
         self._console.print()

-    def print_green(self, text):
-        self._console.print(f"[bold green]{text}")
+    def print_color(self, text: str, color: str):
+        """Prints text in a given color to the console."""
+        self._console.print(f"[bold {color}]{text}")
         self._console.print()

-    def print_red(self, text):
-        self._console.print(f"[bold red]{text}")
-        self._console.print()
-
-    def print_help(self):
-        self._console.print(Markdown(HELP_STRING))
+    def print_help(self, minimal: bool = False):
+        """Prints the help message to the console."""
+        self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
         self._console.print()
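Taken together, the refactor replaces the per-color helpers with one parameterized method. A hypothetical caller now looks like this (the model and user names are made up, and the class above is assumed to be in scope):

interface = RichInterface(model_name="my-org/my-model", user_name="alice")
interface.print_user_message("Hello there!")
interface.print_color(text="Set max_new_tokens to 64.", color="green")
interface.print_help(minimal=True)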
 @dataclass
 class ChatArguments:
     r"""
-    Arguments for the chat script.
+    Arguments for the chat CLI.

-    Args:
-        model_name_or_path (`str`):
-            Name of the pre-trained model.
-        user (`str` or `None`, *optional*, defaults to `None`):
-            Username to display in chat interface.
-        system_prompt (`str` or `None`, *optional*, defaults to `None`):
-            System prompt.
-        save_folder (`str`, *optional*, defaults to `"./chat_history/"`):
-            Folder to save chat history.
-        device (`str`, *optional*, defaults to `"cpu"`):
-            Device to use for inference.
-        examples_path (`str` or `None`, *optional*, defaults to `None`):
-            Path to a yaml file with examples.
-        max_new_tokens (`int`, *optional*, defaults to `256`):
-            Maximum number of tokens to generate.
-        do_sample (`bool`, *optional*, defaults to `True`):
-            Whether to sample outputs during generation.
-        num_beams (`int`, *optional*, defaults to `1`):
-            Number of beams for beam search.
-        temperature (`float`, *optional*, defaults to `1.0`):
-            Temperature parameter for generation.
-        top_k (`int`, *optional*, defaults to `50`):
-            Value of k for top-k sampling.
-        top_p (`float`, *optional*, defaults to `1.0`):
-            Value of p for nucleus sampling.
-        repetition_penalty (`float`, *optional*, defaults to `1.0`):
-            Repetition penalty.
-        eos_tokens (`str` or `None`, *optional*, defaults to `None`):
-            EOS tokens to stop the generation. If multiple they should be comma separated.
-        eos_token_ids (`str` or `None`, *optional*, defaults to `None`):
-            EOS token IDs to stop the generation. If multiple they should be comma separated.
-        model_revision (`str`, *optional*, defaults to `"main"`):
-            Specific model version to use (can be a branch name, tag name or commit id).
-        torch_dtype (`str` or `None`, *optional*, defaults to `None`):
-            Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype
-            will be automatically derived from the model's weights.
-        trust_remote_code (`bool`, *optional*, defaults to `False`):
-            Whether to trust remote code when loading a model.
-        attn_implementation (`str` or `None`, *optional*, defaults to `None`):
-            Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
-            you must install this manually by running `pip install flash-attn --no-build-isolation`.
-        load_in_8bit (`bool`, *optional*, defaults to `False`):
-            Whether to use 8 bit precision for the base model - works only with LoRA.
-        load_in_4bit (`bool`, *optional*, defaults to `False`):
-            Whether to use 4 bit precision for the base model - works only with LoRA.
-        bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
-            Quantization type.
-        use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
-            Whether to use nested quantization.
+    See the metadata arg for each argument's description -- the metadata will be printed with
+    `transformers chat --help`
     """

     # General settings
-    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "Name of the pre-trained model."})
-    user: Optional[str] = field(default=None, metadata={"help": "Username to display in chat interface."})
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Name of the pre-trained model. The positional argument will take precedence if both are passed."
+        },
+    )
+    user: Optional[str] = field(
+        default=None,
+        metadata={"help": "Username to display in chat interface. Defaults to the current user's name."},
+    )
     system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."})
     save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."})
-    device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
     examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})

     # Generation settings
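Since the per-field documentation now lives in each field's metadata, the CLI's --help output is generated from it. Outside the CLI, the same dataclass can be materialized with HfArgumentParser; an illustrative sketch (the checkpoint name is arbitrary, and ChatArguments is assumed importable):

from transformers import HfArgumentParser

parser = HfArgumentParser((ChatArguments,))
(chat_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "Qwen/Qwen2.5-0.5B-Instruct", "--max_new_tokens", "64"]
)
print(chat_args.max_new_tokens)  # 64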
@@ -387,6 +228,7 @@ class ChatArguments:
         default="main",
         metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
     )
+    device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
     torch_dtype: Optional[str] = field(
         default="auto",
         metadata={
@@ -434,7 +276,7 @@ class ChatCommand(BaseTransformersCLICommand):
             parser: Root parser to register command-specific arguments
         """
         dataclass_types = (ChatArguments,)
-        chat_parser = parser.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types)
+        chat_parser = parser.add_parser("chat", dataclass_types=dataclass_types)

         group = chat_parser.add_argument_group("Positional arguments")
         group.add_argument(
@@ -447,10 +289,123 @@ class ChatCommand(BaseTransformersCLICommand):
         args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path

         if args.model_name_or_path is None:
-            raise ValueError("--model_name_or_path required for chat command.")
+            raise ValueError(
+                "One of the following must be provided:"
+                "\n- The positional argument containing the model repo;"
+                "\n- the optional --model_name_or_path argument, containing the model repo"
+                "\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
+            )

         self.args = args

+    # -----------------------------------------------------------------------------------------------------------------
+    # Chat session methods
+    @staticmethod
+    def get_username() -> str:
+        """Returns the username of the current user."""
+        if platform.system() == "Windows":
+            return os.getlogin()
+        else:
+            return pwd.getpwuid(os.getuid()).pw_name
+
+    @staticmethod
+    def save_chat(chat, args: ChatArguments, filename: Optional[str] = None) -> str:
+        """Saves the chat history to a file."""
+        output_dict = {}
+        output_dict["settings"] = vars(args)
+        output_dict["chat_history"] = chat
+
+        folder = args.save_folder
+
+        if filename is None:
+            time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
+            filename = f"{args.model_name_or_path}/chat_{time_str}.json"
+        filename = os.path.join(folder, filename)
+
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        with open(filename, "w") as f:
+            json.dump(output_dict, f, indent=4)
+        return os.path.abspath(filename)
+
+    @staticmethod
+    def clear_chat_history(system_prompt: Optional[str] = None) -> list[dict]:
+        """Clears the chat history."""
+        if system_prompt is None:
+            chat = []
+        else:
+            chat = [{"role": "system", "content": system_prompt}]
+        return chat
+
+    # -----------------------------------------------------------------------------------------------------------------
+    # Input parsing methods
+    @staticmethod
+    def parse_settings(
+        user_input: str, current_args: ChatArguments, interface: RichInterface
+    ) -> tuple[ChatArguments, bool]:
+        """Parses the settings from the user input into the CLI arguments."""
+        settings = user_input[4:].strip().split(";")
+        settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
+        settings = dict(settings)
+        error = False
+
+        for name in settings:
+            if hasattr(current_args, name):
+                try:
+                    if isinstance(getattr(current_args, name), bool):
+                        if settings[name] == "True":
+                            settings[name] = True
+                        elif settings[name] == "False":
+                            settings[name] = False
+                        else:
+                            raise ValueError
+                    else:
+                        settings[name] = type(getattr(current_args, name))(settings[name])
+                except ValueError:
+                    error = True
+                    interface.print_color(
+                        text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
+                        color="red",
+                    )
+            else:
+                interface.print_color(text=f"There is no '{name}' setting.", color="red")
+
+        if error:
+            interface.print_color(
+                text="There was an issue parsing the settings. No settings have been changed.",
+                color="red",
+            )
+        else:
+            for name in settings:
+                setattr(current_args, name, settings[name])
+                interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
+
+        time.sleep(1.5)  # so the user has time to read the changes
+
+        return current_args, not error
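A worked trace of the slicing above, which splits each key=value pair on the first '=' only (so values may themselves contain '='); casting to the field's type happens in the loop that follows:

user_input = "set max_new_tokens=64;temperature=0.7"
settings = user_input[4:].strip().split(";")
# -> ['max_new_tokens=64', 'temperature=0.7']
settings = [(s.split("=")[0], s[len(s.split("=")[0]) + 1 :]) for s in settings]
print(dict(settings))  # {'max_new_tokens': '64', 'temperature': '0.7'}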
+    @staticmethod
+    def parse_eos_tokens(
+        tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
+    ) -> tuple[int, list[int]]:
+        """Retrieves the pad token ID and all possible EOS token IDs."""
+        if tokenizer.pad_token_id is None:
+            pad_token_id = tokenizer.eos_token_id
+        else:
+            pad_token_id = tokenizer.pad_token_id
+
+        all_eos_token_ids = []
+
+        if eos_tokens is not None:
+            all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
+
+        if eos_token_ids is not None:
+            all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
+
+        if len(all_eos_token_ids) == 0:
+            all_eos_token_ids.append(tokenizer.eos_token_id)
+
+        return pad_token_id, all_eos_token_ids
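To see the merge behavior of parse_eos_tokens without loading a real tokenizer, a stub with the same two attributes and one method is enough (the token IDs are made up, and ChatCommand is assumed importable):

class StubTokenizer:
    pad_token_id = None
    eos_token_id = 2

    def convert_tokens_to_ids(self, tokens):
        return [100 + i for i, _ in enumerate(tokens)]  # fake IDs

pad_id, eos_ids = ChatCommand.parse_eos_tokens(StubTokenizer(), "</s>,<|im_end|>", "42")
print(pad_id, eos_ids)  # 2 [100, 101, 42]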
     @staticmethod
     def is_valid_setting_command(s: str) -> bool:
         # First check the basic structure
@@ -481,6 +436,55 @@ class ChatCommand(BaseTransformersCLICommand):

         return True

+    # -----------------------------------------------------------------------------------------------------------------
+    # Model loading and performance automation methods
+    @staticmethod
+    def get_quantization_config(model_args: ChatArguments) -> Optional["BitsAndBytesConfig"]:
+        if model_args.load_in_4bit:
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                # For consistency with model weights, we use the same value as `torch_dtype`
+                bnb_4bit_compute_dtype=model_args.torch_dtype,
+                bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
+                bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
+                bnb_4bit_quant_storage=model_args.torch_dtype,
+            )
+        elif model_args.load_in_8bit:
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=True,
+            )
+        else:
+            quantization_config = None
+
+        return quantization_config
+
+    def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_name_or_path,
+            revision=args.model_revision,
+            trust_remote_code=args.trust_remote_code,
+        )
+
+        torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
+        quantization_config = self.get_quantization_config(args)
+        model_kwargs = {
+            "revision": args.model_revision,
+            "attn_implementation": args.attn_implementation,
+            "torch_dtype": torch_dtype,
+            "device_map": "auto",
+            "quantization_config": quantization_config,
+        }
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
+        )
+
+        if getattr(model, "hf_device_map", None) is None:
+            model = model.to(args.device)
+
+        return model, tokenizer
+
+    # -----------------------------------------------------------------------------------------------------------------
+    # Main logic
     def run(self):
         if not is_rich_available():
             raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
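Outside the command class, the same loading path can be reproduced directly. A minimal sketch with an arbitrary public checkpoint (the CLI additionally threads through the revision, attention implementation, quantization config, and a device fallback):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="main")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision="main",
    torch_dtype="auto",  # mirrors the CLI's "auto" default
    device_map="auto",
)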
@@ -497,24 +501,27 @@ class ChatCommand(BaseTransformersCLICommand):
         current_args = copy.deepcopy(args)

         if args.user is None:
-            user = get_username()
+            user = self.get_username()
         else:
             user = args.user

-        model, tokenizer = load_model_and_tokenizer(args)
+        model, tokenizer = self.load_model_and_tokenizer(args)
         generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

-        pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+        pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)

         interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
         interface.clear()
-        chat = clear_chat_history(current_args.system_prompt)
+        chat = self.clear_chat_history(current_args.system_prompt)

+        # Starts the session with a minimal help message at the top, so that a user doesn't get stuck
+        interface.print_help(minimal=True)
         while True:
             try:
                 user_input = interface.input()

                 if user_input == "clear":
-                    chat = clear_chat_history(current_args.system_prompt)
+                    chat = self.clear_chat_history(current_args.system_prompt)
                     interface.clear()
                     continue
@@ -528,7 +535,7 @@ class ChatCommand(BaseTransformersCLICommand):
                 if user_input == "reset":
                     interface.clear()
                     current_args = copy.deepcopy(args)
-                    chat = clear_chat_history(current_args.system_prompt)
+                    chat = self.clear_chat_history(current_args.system_prompt)
                     continue

                 if user_input.startswith("save") and len(user_input.split()) < 2:
@@ -538,12 +545,12 @@ class ChatCommand(BaseTransformersCLICommand):
                         filename = split_input[1]
                     else:
                         filename = None
-                    filename = save_chat(chat, current_args, filename)
-                    interface.print_green(f"Chat saved in {filename}!")
+                    filename = self.save_chat(chat, current_args, filename)
+                    interface.print_color(text=f"Chat saved in {filename}!", color="green")
                     continue

                 if self.is_valid_setting_command(user_input):
-                    current_args, success = parse_settings(user_input, current_args, interface)
+                    current_args, success = self.parse_settings(user_input, current_args, interface)
                     if success:
                         chat = []
                         interface.clear()
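For reference, the default filename logic in save_chat nests the file under the model name inside save_folder; note that it writes `.json`, even though the help text above advertises `.yaml`. With a hypothetical repo id:

import os
import time

folder = "./chat_history/"
model_name_or_path = "my-org/my-model"  # hypothetical repo id
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
print(os.path.join(folder, f"{model_name_or_path}/chat_{time_str}.json"))
# ./chat_history/my-org/my-model/chat_2025-01-01_12-00-00.json (timestamp varies)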
@@ -557,9 +564,10 @@ class ChatCommand(BaseTransformersCLICommand):
                         interface.print_user_message(examples[example_name]["text"])
                         user_input = examples[example_name]["text"]
                     else:
-                        interface.print_red(
+                        example_error = (
                             f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
                         )
+                        interface.print_color(text=example_error, color="red")
                         continue

                 chat.append({"role": "user", "content": user_input})
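The diff ends before the generation step. Based on the pieces wired up in run() (the streamer, the pad/eos ids, and stream_output), a plausible continuation of the loop looks like this; it is a sketch assuming the surrounding run() names, not the committed code:

from threading import Thread

generation_kwargs = {
    "inputs": tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device),
    "streamer": generation_streamer,
    "max_new_tokens": current_args.max_new_tokens,
    "pad_token_id": pad_token_id,
    "eos_token_id": eos_token_ids,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
chat.append({"role": "assistant", "content": model_output})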