Compare commits

...

1 Commits

Author SHA1 Message Date
bde9f6294d Base callback API for agents 2023-06-28 09:54:40 -04:00

View File

@ -188,6 +188,30 @@ def clean_code_for_run(result):
return explanation, code
class AgentCallback:
"""
Class for callbacks used by the [`Agent`] when running [`Agent.run`] or [`Agent.chat`].
"""
def on_prompt_formatted(self, prompt):
"""
This event is called just before the prompt is passed to the LLM powering the agent.
"""
pass
def on_llm_returned(self, result):
"""
This event is called just after the LLM powering the agent finished generating its answer.
"""
pass
def on_result_cleaned(self, code):
"""
This event is called just after the result from the LLM is parsed and the code in its answer is extracted.
"""
pass
class Agent:
"""
Base class for all agents which contains the main API methods.
@ -206,7 +230,7 @@ class Agent:
one of the default tools, that default tool will be overridden.
"""
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None, callback=None):
_setup_default_tools()
agent_name = self.__class__.__name__
@ -231,6 +255,7 @@ class Agent:
name = list(replacements.keys())[0]
logger.warn(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")
self.callback = callback
self.prepare_for_new_chat()
@property
@ -285,9 +310,16 @@ class Agent:
```
"""
prompt = self.format_prompt(task, chat_mode=True)
if self.callback is not None:
self.callback.on_prompt_formatted(prompt)
result = self.generate_one(prompt, stop=["Human:", "====="])
if self.callback is not None:
self.callback.on_llm_returned(result)
self.chat_history = prompt + result.strip() + "\n"
explanation, code = clean_code_for_chat(result)
if self.callback is not None:
self.callback.on_result_cleaned(code)
self.log(f"==Explanation from the agent==\n{explanation}")
@ -333,8 +365,14 @@ class Agent:
```
"""
prompt = self.format_prompt(task)
if self.callback is not None:
self.callback.on_prompt_formatted(prompt)
result = self.generate_one(prompt, stop=["Task:"])
if self.callback is not None:
self.callback.on_llm_returned(result)
explanation, code = clean_code_for_run(result)
if self.callback is not None:
self.callback.on_result_cleaned(code)
self.log(f"==Explanation from the agent==\n{explanation}")