mirror of https://github.com/microsoft/autogen.git
426 lines
21 KiB
Python
426 lines
21 KiB
Python
import os
|
|
from autogen import oai
|
|
from autogen.agentchat.agent import Agent
|
|
from autogen.agentchat.assistant_agent import ConversableAgent
|
|
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
|
|
from typing import Callable, Dict, Optional, Union, List, Tuple, Any
|
|
import chromadb
|
|
from chromadb.config import Settings
|
|
import pickle
|
|
|
|
|
|
# termcolor is an optional dependency: use it when available, otherwise fall
# back to a no-op so that every colored(...) call in this module still works.
try:
    from termcolor import colored
except ImportError:

    def colored(x, *_args, **_kwargs):
        """Return the text unchanged; color and attribute arguments are ignored."""
        return x
|
|
|
|
|
|
class TeachableAgent(ConversableAgent):
    """Teachable Agent, a subclass of ConversableAgent using a vector database to remember user teachings.

    In this class, the term 'user' refers to any caller (human or not) sending messages to this agent.
    Not yet tested in the group-chat setting.
    """

    def __init__(
        self,
        name="teachableagent",
        system_message: Optional[
            str
        ] = "You are a helpful AI assistant that remembers user teachings from prior chats.",
        human_input_mode: Optional[str] = "NEVER",
        llm_config: Optional[Union[Dict, bool]] = None,
        analyzer_llm_config: Optional[Union[Dict, bool]] = None,
        teach_config: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Args:
            name (str): name of the agent.
            system_message (str): system message for the ChatCompletion inference.
            human_input_mode (str): This agent should NEVER prompt the human for input.
            llm_config (dict or False): llm inference configuration.
                Please refer to [Completion.create](/docs/reference/oai/completion#create)
                for available options.
                To disable llm-based auto reply, set to False.
            analyzer_llm_config (dict or False): llm inference configuration passed to TextAnalyzerAgent.
                Given the default setting of None, TeachableAgent passes its own llm_config to TextAnalyzerAgent.
            teach_config (dict or None): Additional parameters used by TeachableAgent.
                To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
                - verbosity (Optional, int): 0 (default) for basic info, 1 to add memory operations,
                    2 for analyzer messages, 3 for memo lists.
                - reset_db (Optional, bool): True to clear the DB before starting. Default False.
                - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
                    Default "./tmp/teachable_agent_db"
                - prepopulate (Optional, bool): Default True. The value is only stored on the instance;
                    nothing in this class reads it, so call prepopulate_db() explicitly to add the sample memos.
                - recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is
                    exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
                - max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB.
                    Default 10.
            **kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
        """
        super().__init__(
            name=name,
            system_message=system_message,
            human_input_mode=human_input_mode,
            llm_config=llm_config,
            **kwargs,
        )
        # Register a custom reply function.
        self.register_reply(Agent, TeachableAgent._generate_teachable_assistant_reply, 1)

        # Assemble the parameter settings.
        self._teach_config = {} if teach_config is None else teach_config
        self.verbosity = self._teach_config.get("verbosity", 0)
        self.reset_db = self._teach_config.get("reset_db", False)
        self.path_to_db_dir = self._teach_config.get("path_to_db_dir", "./tmp/teachable_agent_db")
        self.prepopulate = self._teach_config.get("prepopulate", True)
        self.recall_threshold = self._teach_config.get("recall_threshold", 1.5)
        self.max_num_retrievals = self._teach_config.get("max_num_retrievals", 10)

        # Create the analyzer. It shares this agent's llm_config unless given its own.
        if analyzer_llm_config is None:
            analyzer_llm_config = llm_config
        self.analyzer = TextAnalyzerAgent(llm_config=analyzer_llm_config)

        # Create the memo store.
        self.memo_store = MemoStore(self.verbosity, self.reset_db, self.path_to_db_dir)
        self.user_comments = []  # Stores user comments until the end of each chat.

    def close_db(self):
        """Cleanly closes the memo store."""
        self.memo_store.close()

    def prepopulate_db(self):
        """Adds a few arbitrary memos to the DB."""
        self.memo_store.prepopulate()

    def _generate_teachable_assistant_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[Agent] = None,
        config: Optional[Any] = None,  # Persistent state.
    ) -> Tuple[bool, Union[str, Dict, None]]:
        """
        Generates a reply to the last user message, after querying the memo store for relevant information.
        Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.

        Args:
            messages: the chat history to reply to. When None, the history stored for `sender` is used.
            sender: the agent whose stored history is used when `messages` is None.
            config: unused by this method.

        Returns:
            A tuple whose first element is always True and whose second element is the generated reply.

        Raises:
            ValueError: if llm_config is False, if the last message's content is not a string,
                or if the last message carries a "context" key.
        """
        if self.llm_config is False:
            raise ValueError("TeachableAgent requires self.llm_config to be set in its base class.")
        if messages is None:
            messages = self._oai_messages[sender]  # In case of a direct call.

        # Get the last user turn.
        last_message = messages[-1]
        user_text = last_message["content"]
        if (not isinstance(user_text, str)) or ("context" in last_message):
            raise ValueError(
                "TeachableAgent currently assumes that the message content is a simple string. This error serves to flag a test case for relaxing this assumption."
            )

        # Keep track of this user turn as a potential source of memos later.
        self.user_comments.append(user_text)

        # Consider whether to retrieve something from the DB.
        if self.memo_store.last_memo_id > 0:
            new_user_text = self.consider_memo_retrieval(user_text)
            if new_user_text != user_text:
                # list.copy() is shallow, so assigning into messages[-1]["content"] would edit the dict
                # that the stored chat history also holds. Copy the list AND swap in a fresh dict for the
                # last message, so the memo-augmented text is only seen by this one completion request.
                messages = messages.copy()
                messages[-1] = {**last_message, "content": new_user_text}

        # Generate a response.
        msgs = self._oai_system_message + messages
        response = oai.ChatCompletion.create(messages=msgs, **self.llm_config)
        response_text = oai.ChatCompletion.extract_text_or_function_call(response)[0]
        return True, response_text

    def learn_from_user_feedback(self):
        """Reviews the user comments from the last chat, and decides what teachings to store as memos."""
        print(colored("\nREVIEWING CHAT FOR USER TEACHINGS TO REMEMBER", "light_yellow"))
        # Look at each user turn, and consider whether to store something from it in the DB.
        for comment in self.user_comments:
            self.consider_memo_storage(comment)
        # Start the next chat with an empty backlog.
        self.user_comments = []

    def consider_memo_storage(self, comment):
        """Decides whether to store something from one user comment in the DB.

        Two independent checks are made on the comment: one for a task accompanied by advice,
        and one for a plain piece of information. Each can add its own memo.
        """
        # Check for a problem-solution pair.
        response = self.analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Can we extract advice?
            advice = self.analyze(
                comment,
                "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
            )
            if "none" not in advice.lower():
                # Yes. Extract the task.
                task = self.analyze(
                    comment,
                    "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
                )
                # Generalize the task.
                general_task = self.analyze(
                    task,
                    "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
                )
                # Add the task-advice (problem-solution) pair to the vector DB.
                if self.verbosity >= 1:
                    print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
                self.memo_store.add_input_output_pair(general_task, advice)

        # Check for information to be learned.
        response = self.analyze(
            comment,
            "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            # Yes. What question would this information answer?
            question = self.analyze(
                comment,
                "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
            )
            # Extract the information.
            answer = self.analyze(
                comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
            )
            # Add the question-answer pair to the vector DB.
            if self.verbosity >= 1:
                print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
            self.memo_store.add_input_output_pair(question, answer)

    def consider_memo_retrieval(self, comment):
        """Decides whether to retrieve memos from the DB, and add them to the chat context.

        Returns the comment itself when no memo is recalled, otherwise the comment followed by
        the text produced by concatenate_memo_texts().
        """
        # First, use the user comment directly as the lookup key.
        if self.verbosity >= 1:
            print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
        memo_list = self.retrieve_relevant_memos(comment)

        # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
        response = self.analyze(
            comment,
            "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
        )
        if "yes" in response.lower():
            if self.verbosity >= 1:
                print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
            # Extract the task.
            task = self.analyze(
                comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
            )
            # Generalize the task.
            general_task = self.analyze(
                task,
                "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
            )
            # Append any relevant memos.
            memo_list.extend(self.retrieve_relevant_memos(general_task))

        # De-duplicate the memo list. dict.fromkeys keeps first-seen order, so the prompt is
        # reproducible from run to run; a set would shuffle the memos arbitrarily.
        memo_list = list(dict.fromkeys(memo_list))

        # Append the memos to the last user message.
        return comment + self.concatenate_memo_texts(memo_list)

    def retrieve_relevant_memos(self, input_text):
        """Returns the output texts of semantically related memos from the DB."""
        related = self.memo_store.get_related_memos(
            input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
        )

        # Was anything retrieved?
        if self.verbosity >= 1 and not related:
            # No. Look at the closest memo.
            print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
            self.memo_store.get_nearest_memo(input_text)
            print()  # Print a blank line. The memo details were printed by get_nearest_memo().

        # Each memo is an (input_text, output_text, distance) tuple; keep just the output_text strings.
        return [memo[1] for memo in related]

    def concatenate_memo_texts(self, memo_list):
        """Concatenates the memo texts into a single string for inclusion in the chat context.

        Returns an empty string when memo_list is empty.
        """
        if not memo_list:
            return ""
        info = "\n# Memories that might help\n" + "".join("- " + memo + "\n" for memo in memo_list)
        if self.verbosity >= 1:
            print(colored("\nMEMOS APPENDED TO LAST USER MESSAGE...\n" + info + "\n", "light_yellow"))
        return "\n" + info

    def analyze(self, text_to_analyze, analysis_instructions):
        """Asks TextAnalyzerAgent to analyze the given text according to specific instructions.

        Returns the analyzer's reply.
        """
        if self.verbosity >= 2:
            # Use the messaging mechanism so that the analyzer's messages are included in the printed chat.
            self.analyzer.reset()  # Clear the analyzer's list of messages.
            self.send(
                recipient=self.analyzer, message=text_to_analyze, request_reply=False
            )  # Put the message in the analyzer's list.
            self.send(recipient=self.analyzer, message=analysis_instructions, request_reply=True)  # Request the reply.
            return self.last_message(self.analyzer)["content"]

        # Use the analyzer's method directly, to leave analyzer message out of the printed chat.
        return self.analyzer.analyze_text(text_to_analyze, analysis_instructions)
|
|
|
|
|
|
class MemoStore:
    """
    Provides memory storage and retrieval for a TeachableAgent, using a vector database.
    Each DB entry (called a memo) is a pair of strings: an input text and an output text.
    The input text might be a question, or a task to perform.
    The output text might be an answer to the question, or advice on how to perform the task.
    Vector embeddings are currently supplied by Chroma's default Sentence Transformers.

    Only the input texts go into the vector DB. The (input_text, output_text) pairs live in
    uid_text_dict, keyed by the same string ID, and are pickled to disk by close().
    """

    def __init__(self, verbosity, reset, path_to_db_dir):
        """
        Args:
            - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
            - reset (bool): True to clear the DB instead of loading the saved memos.
            - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
        """
        self.verbosity = verbosity
        self.reset = reset
        self.path_to_db_dir = path_to_db_dir

        # Load or create the vector DB on disk.
        settings = Settings(
            anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
        )
        self.db_client = chromadb.Client(settings)
        self.vec_db = self.db_client.create_collection("memos", get_or_create=True)  # The collection is the DB.
        if reset:
            self.reset_db()

        # Load or create the associated memo dict on disk.
        self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
        self.uid_text_dict = {}
        self.last_memo_id = 0
        if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
            print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
            # NOTE(security): unpickling can execute arbitrary code, so path_to_db_dir must only
            # ever point at a directory whose contents are trusted.
            with open(self.path_to_dict, "rb") as f:
                self.uid_text_dict = pickle.load(f)
                # IDs are handed out sequentially from 1, so the count is also the last ID used.
                self.last_memo_id = len(self.uid_text_dict)
                if self.verbosity >= 3:
                    self.list_memos()

    def list_memos(self):
        """Prints the contents of MemoStore."""
        print(colored("LIST OF MEMOS", "light_green"))
        for uid, (input_text, output_text) in self.uid_text_dict.items():
            print(
                colored(
                    " ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text),
                    "light_green",
                )
            )

    def close(self):
        """Saves self.uid_text_dict to disk."""
        print(colored("\nSAVING MEMORY TO DISK", "light_green"))
        print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
        with open(self.path_to_dict, "wb") as file:
            pickle.dump(self.uid_text_dict, file)

    def reset_db(self):
        """Forces immediate deletion of the DB's contents.

        The "memos" collection is dropped and recreated, and the in-memory dict and ID counter are
        cleared. The pickled dict on disk is not touched here; it is overwritten by the next close().
        """
        print(colored("\nCLEARING MEMORY", "light_green"))
        self.db_client.delete_collection("memos")
        self.vec_db = self.db_client.create_collection("memos")
        self.uid_text_dict = {}
        # Restart the ID counter too; otherwise callers testing last_memo_id > 0 would still
        # believe the store holds memos and try to query the empty collection.
        self.last_memo_id = 0

    def add_input_output_pair(self, input_text, output_text):
        """Adds an input-output pair to the vector DB."""
        self.last_memo_id += 1
        uid = str(self.last_memo_id)
        # Only the input text is embedded; the output text is looked up by ID after a query.
        self.vec_db.add(documents=[input_text], ids=[uid])
        self.uid_text_dict[uid] = input_text, output_text
        if self.verbosity >= 1:
            print(
                colored(
                    "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}".format(
                        self.last_memo_id, input_text, output_text
                    ),
                    "light_green",
                )
            )
        if self.verbosity >= 3:
            self.list_memos()

    def get_nearest_memo(self, query_text):
        """Retrieves the nearest memo to the given query text.

        Returns an (input_text, output_text, distance) tuple. The first query result is indexed
        directly, so the store is expected to hold at least one memo.
        """
        results = self.vec_db.query(query_texts=[query_text], n_results=1)
        uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
        input_text_2, output_text = self.uid_text_dict[uid]
        # The vector DB and the dict must agree on what was stored under this ID.
        assert input_text == input_text_2
        if self.verbosity >= 1:
            print(
                colored(
                    "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
                        input_text, output_text, distance
                    ),
                    "light_green",
                )
            )
        return input_text, output_text, distance

    def get_related_memos(self, query_text, n_results, threshold):
        """Retrieves memos that are related to the given query text within the specified distance threshold.

        Returns a list of (input_text, output_text, distance) tuples in the order the vector DB
        reported them. The list is empty when the store holds no memos or n_results is not positive.
        """
        # Never ask for more results than there are memos.
        n_results = min(n_results, len(self.uid_text_dict))
        if n_results <= 0:
            # Nothing can match, and a query for zero results is not a meaningful request to send.
            return []
        results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
        memos = []
        for uid, input_text, distance in zip(results["ids"][0], results["documents"][0], results["distances"][0]):
            if distance >= threshold:
                continue
            input_text_2, output_text = self.uid_text_dict[uid]
            # The vector DB and the dict must agree on what was stored under this ID.
            assert input_text == input_text_2
            if self.verbosity >= 1:
                print(
                    colored(
                        "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
                            input_text, output_text, distance
                        ),
                        "light_green",
                    )
                )
            memos.append((input_text, output_text, distance))
        return memos

    def prepopulate(self):
        """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
        if self.verbosity >= 1:
            print(colored("\nPREPOPULATING MEMORY", "light_green"))
        # Each "text" becomes a memo's input text and each "label" its output text.
        examples = [
            {"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"},
            {"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"},
            {"text": "Tell gpt the output should still be latex code.", "label": "no"},
            {"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"},
            {"text": "To create a good PPT, include enough content to make it interesting.", "label": "yes"},
            {
                "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
                "label": "no",
            },
            {"text": "When writing code, remember to include any libraries that are used.", "label": "yes"},
            {"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"},
            {"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"},
            {
                "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
                "label": "yes",
            },
        ]
        for example in examples:
            self.add_input_output_pair(example["text"], example["label"])
|