diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml
index bb63bc665..0a678fd9e 100644
--- a/.github/workflows/openai.yml
+++ b/.github/workflows/openai.yml
@@ -53,6 +53,9 @@ jobs:
if: matrix.python-version == '3.9'
run: |
pip install -e .[retrievechat]
+ - name: Install packages for Teachable when needed
+ run: |
+ pip install -e .[teachable]
- name: Coverage
if: matrix.python-version == '3.9'
env:
diff --git a/.gitignore b/.gitignore
index 98517f9d6..479178234 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,6 @@ key_openai.txt
key_aoai.txt
base_aoai.txt
wolfram.txt
+
+# DB on disk for TeachableAgent
+tmp/
diff --git a/README.md b/README.md
index 41fdb6c4d..9d425abd1 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
[](https://badge.fury.io/py/pyautogen)
[](https://github.com/microsoft/autogen/actions/workflows/python-package.yml)

+[](https://pepy.tech/project/pyautogen)
[](https://discord.gg/pAbnFJrkgZ)
This project is a spinoff from [FLAML](https://github.com/microsoft/FLAML).
@@ -137,13 +138,15 @@ In addition, you can find:
- [Research](https://microsoft.github.io/autogen/docs/Research), [blogposts](https://microsoft.github.io/autogen/blog) around AutoGen, and [Transparency FAQs](https://github.com/microsoft/autogen/blob/main/TRANSPARENCY_FAQS.md)
-- [Discord](https://discord.gg/pAbnFJrkgZ).
+- [Discord](https://discord.gg/pAbnFJrkgZ)
-- [Contributing guide](https://microsoft.github.io/autogen/docs/Contribute).
+- [Contributing guide](https://microsoft.github.io/autogen/docs/Contribute)
+
+- [Roadmap](https://github.com/orgs/microsoft/projects/989/views/3)
## Citation
-[AutoGen](https://arxiv.org/abs/2308.08155).
+[AutoGen](https://arxiv.org/abs/2308.08155)
```
@inproceedings{wu2023autogen,
@@ -156,7 +159,7 @@ In addition, you can find:
}
```
-[EcoOptiGen](https://arxiv.org/abs/2303.04673).
+[EcoOptiGen](https://arxiv.org/abs/2303.04673)
```
@inproceedings{wang2023EcoOptiGen,
@@ -167,7 +170,7 @@ In addition, you can find:
}
```
-[MathChat](https://arxiv.org/abs/2306.01337).
+[MathChat](https://arxiv.org/abs/2306.01337)
```
@inproceedings{wu2023empirical,
@@ -194,6 +197,11 @@ This project has adopted the [Microsoft Open Source Code of Conduct](https://ope
For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+## Contributers Wall
+
+
+
+
# Legal Notices
Microsoft and any contributors grant you a license to the Microsoft documentation and other content
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index 94677244a..40a146e93 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -1,5 +1,9 @@
import re
-import chromadb
+
+try:
+ import chromadb
+except ImportError:
+ raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
from autogen.agentchat.agent import Agent
from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, num_tokens_from_text
diff --git a/autogen/agentchat/contrib/teachable_agent.py b/autogen/agentchat/contrib/teachable_agent.py
new file mode 100644
index 000000000..8db5b699e
--- /dev/null
+++ b/autogen/agentchat/contrib/teachable_agent.py
@@ -0,0 +1,425 @@
+import os
+from autogen import oai
+from autogen.agentchat.agent import Agent
+from autogen.agentchat.assistant_agent import ConversableAgent
+from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
+from typing import Callable, Dict, Optional, Union, List, Tuple, Any
+import chromadb
+from chromadb.config import Settings
+import pickle
+
+
+try:
+ from termcolor import colored
+except ImportError:
+
+ def colored(x, *args, **kwargs):
+ return x
+
+
+class TeachableAgent(ConversableAgent):
+ """Teachable Agent, a subclass of ConversableAgent using a vector database to remember user teachings.
+ In this class, the term 'user' refers to any caller (human or not) sending messages to this agent.
+ Not yet tested in the group-chat setting."""
+
+ def __init__(
+ self,
+ name="teachableagent",
+ system_message: Optional[
+ str
+ ] = "You are a helpful AI assistant that remembers user teachings from prior chats.",
+ human_input_mode: Optional[str] = "NEVER",
+ llm_config: Optional[Union[Dict, bool]] = None,
+ analyzer_llm_config: Optional[Union[Dict, bool]] = None,
+ teach_config: Optional[Dict] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): name of the agent.
+ system_message (str): system message for the ChatCompletion inference.
+ human_input_mode (str): This agent should NEVER prompt the human for input.
+ llm_config (dict or False): llm inference configuration.
+ Please refer to [Completion.create](/docs/reference/oai/completion#create)
+ for available options.
+ To disable llm-based auto reply, set to False.
+ analyzer_llm_config (dict or False): llm inference configuration passed to TextAnalyzerAgent.
+ Given the default setting of None, TeachableAgent passes its own llm_config to TextAnalyzerAgent.
+ teach_config (dict or None): Additional parameters used by TeachableAgent.
+ To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
+ - verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
+ - reset_db (Optional, bool): True to clear the DB before starting. Default False.
+ - path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
+ - prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
+ - recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
+ - max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
+ **kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ human_input_mode=human_input_mode,
+ llm_config=llm_config,
+ **kwargs,
+ )
+ # Register a custom reply function.
+ self.register_reply(Agent, TeachableAgent._generate_teachable_assistant_reply, 1)
+
+ # Assemble the parameter settings.
+ self._teach_config = {} if teach_config is None else teach_config
+ self.verbosity = self._teach_config.get("verbosity", 0)
+ self.reset_db = self._teach_config.get("reset_db", False)
+ self.path_to_db_dir = self._teach_config.get("path_to_db_dir", "./tmp/teachable_agent_db")
+ self.prepopulate = self._teach_config.get("prepopulate", True)
+ self.recall_threshold = self._teach_config.get("recall_threshold", 1.5)
+ self.max_num_retrievals = self._teach_config.get("max_num_retrievals", 10)
+
+ # Create the analyzer.
+ if analyzer_llm_config is None:
+ analyzer_llm_config = llm_config
+ self.analyzer = TextAnalyzerAgent(llm_config=analyzer_llm_config)
+
+ # Create the memo store.
+ self.memo_store = MemoStore(self.verbosity, self.reset_db, self.path_to_db_dir)
+ self.user_comments = [] # Stores user comments until the end of each chat.
+
+ def close_db(self):
+ """Cleanly closes the memo store."""
+ self.memo_store.close()
+
+ def prepopulate_db(self):
+ """Adds a few arbitrary memos to the DB."""
+ self.memo_store.prepopulate()
+
+ def _generate_teachable_assistant_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None, # Persistent state.
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """
+ Generates a reply to the last user message, after querying the memo store for relevant information.
+ Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
+ """
+ if self.llm_config is False:
+ raise ValueError("TeachableAgent requires self.llm_config to be set in its base class.")
+ if messages is None:
+ messages = self._oai_messages[sender] # In case of a direct call.
+
+ # Get the last user turn.
+ last_message = messages[-1]
+ user_text = last_message["content"]
+ if (not isinstance(user_text, str)) or ("context" in last_message):
+ raise ValueError(
+ "TeachableAgent currently assumes that the message content is a simple string. This error serves to flag a test case for relaxing this assumption."
+ )
+
+ # Keep track of this user turn as a potential source of memos later.
+ self.user_comments.append(user_text)
+
+ # Consider whether to retrieve something from the DB.
+ if self.memo_store.last_memo_id > 0:
+ new_user_text = self.consider_memo_retrieval(user_text)
+ if new_user_text != user_text:
+ # Make a copy of the message list, and replace the last user message with the new one.
+ messages = messages.copy()
+ messages[-1]["content"] = new_user_text
+
+ # Generate a response.
+ msgs = self._oai_system_message + messages
+ response = oai.ChatCompletion.create(messages=msgs, **self.llm_config)
+ response_text = oai.ChatCompletion.extract_text_or_function_call(response)[0]
+ return True, response_text
+
+ def learn_from_user_feedback(self):
+ """Reviews the user comments from the last chat, and decides what teachings to store as memos."""
+ print(colored("\nREVIEWING CHAT FOR USER TEACHINGS TO REMEMBER", "light_yellow"))
+ # Look at each user turn.
+ if len(self.user_comments) > 0:
+ for comment in self.user_comments:
+ # Consider whether to store something from this user turn in the DB.
+ self.consider_memo_storage(comment)
+ self.user_comments = []
+
+ def consider_memo_storage(self, comment):
+ """Decides whether to store something from one user comment in the DB."""
+ # Check for a problem-solution pair.
+ response = self.analyze(
+ comment,
+ "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
+ )
+ if "yes" in response.lower():
+ # Can we extract advice?
+ advice = self.analyze(
+ comment,
+ "Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
+ )
+ if "none" not in advice.lower():
+ # Yes. Extract the task.
+ task = self.analyze(
+ comment,
+ "Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
+ )
+ # Generalize the task.
+ general_task = self.analyze(
+ task,
+ "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
+ )
+ # Add the task-advice (problem-solution) pair to the vector DB.
+ if self.verbosity >= 1:
+ print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
+ self.memo_store.add_input_output_pair(general_task, advice)
+
+ # Check for information to be learned.
+ response = self.analyze(
+ comment,
+ "Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
+ )
+ if "yes" in response.lower():
+ # Yes. What question would this information answer?
+ question = self.analyze(
+ comment,
+ "Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
+ )
+ # Extract the information.
+ answer = self.analyze(
+ comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
+ )
+ # Add the question-answer pair to the vector DB.
+ if self.verbosity >= 1:
+ print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
+ self.memo_store.add_input_output_pair(question, answer)
+
+ def consider_memo_retrieval(self, comment):
+ """Decides whether to retrieve memos from the DB, and add them to the chat context."""
+
+ # First, use the user comment directly as the lookup key.
+ if self.verbosity >= 1:
+ print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
+ memo_list = self.retrieve_relevant_memos(comment)
+
+ # Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
+ response = self.analyze(
+ comment,
+ "Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
+ )
+ if "yes" in response.lower():
+ if self.verbosity >= 1:
+ print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
+ # Extract the task.
+ task = self.analyze(
+ comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
+ )
+ # Generalize the task.
+ general_task = self.analyze(
+ task,
+ "Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
+ )
+ # Append any relevant memos.
+ memo_list.extend(self.retrieve_relevant_memos(general_task))
+
+ # De-duplicate the memo list.
+ memo_list = list(set(memo_list))
+
+ # Append the memos to the last user message.
+ return comment + self.concatenate_memo_texts(memo_list)
+
+ def retrieve_relevant_memos(self, input_text):
+ """Returns semantically related memos from the DB."""
+ memo_list = self.memo_store.get_related_memos(
+ input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
+ )
+
+ if self.verbosity >= 1:
+ # Was anything retrieved?
+ if len(memo_list) == 0:
+ # No. Look at the closest memo.
+ print(colored("\nTHE CLOSEST MEMO IS BEYOND THE THRESHOLD:", "light_yellow"))
+ self.memo_store.get_nearest_memo(input_text)
+ print() # Print a blank line. The memo details were printed by get_nearest_memo().
+
+ # Create a list of just the memo output_text strings.
+ memo_list = [memo[1] for memo in memo_list]
+ return memo_list
+
+ def concatenate_memo_texts(self, memo_list):
+ """Concatenates the memo texts into a single string for inclusion in the chat context."""
+ memo_texts = ""
+ if len(memo_list) > 0:
+ info = "\n# Memories that might help\n"
+ for memo in memo_list:
+ info = info + "- " + memo + "\n"
+ if self.verbosity >= 1:
+ print(colored("\nMEMOS APPENDED TO LAST USER MESSAGE...\n" + info + "\n", "light_yellow"))
+ memo_texts = memo_texts + "\n" + info
+ return memo_texts
+
+ def analyze(self, text_to_analyze, analysis_instructions):
+ """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
+ if self.verbosity >= 2:
+ # Use the messaging mechanism so that the analyzer's messages are included in the printed chat.
+ self.analyzer.reset() # Clear the analyzer's list of messages.
+ self.send(
+ recipient=self.analyzer, message=text_to_analyze, request_reply=False
+ ) # Put the message in the analyzer's list.
+ self.send(recipient=self.analyzer, message=analysis_instructions, request_reply=True) # Request the reply.
+ return self.last_message(self.analyzer)["content"]
+ else:
+ # Use the analyzer's method directly, to leave analyzer message out of the printed chat.
+ return self.analyzer.analyze_text(text_to_analyze, analysis_instructions)
+
+
+class MemoStore:
+ """
+ Provides memory storage and retrieval for a TeachableAgent, using a vector database.
+ Each DB entry (called a memo) is a pair of strings: an input text and an output text.
+ The input text might be a question, or a task to perform.
+ The output text might be an answer to the question, or advice on how to perform the task.
+ Vector embeddings are currently supplied by Chroma's default Sentence Transformers.
+ """
+
+ def __init__(self, verbosity, reset, path_to_db_dir):
+ """
+ Args:
+ - verbosity (Optional, int): 1 to print memory operations, 0 to omit them. 3+ to print memo lists.
+ - path_to_db_dir (Optional, str): path to the directory where the DB is stored.
+ """
+ self.verbosity = verbosity
+ self.reset = reset
+ self.path_to_db_dir = path_to_db_dir
+
+ # Load or create the vector DB on disk.
+ settings = Settings(
+ anonymized_telemetry=False, allow_reset=True, is_persistent=True, persist_directory=path_to_db_dir
+ )
+ self.db_client = chromadb.Client(settings)
+ self.vec_db = self.db_client.create_collection("memos", get_or_create=True) # The collection is the DB.
+ if reset:
+ self.reset_db()
+
+ # Load or create the associated memo dict on disk.
+ self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
+ self.uid_text_dict = {}
+ self.last_memo_id = 0
+ if (not reset) and os.path.exists(self.path_to_dict):
+ print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
+ print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
+ with open(self.path_to_dict, "rb") as f:
+ self.uid_text_dict = pickle.load(f)
+ self.last_memo_id = len(self.uid_text_dict)
+ if self.verbosity >= 3:
+ self.list_memos()
+
+ def list_memos(self):
+ """Prints the contents of MemoStore."""
+ print(colored("LIST OF MEMOS", "light_green"))
+ for uid, text in self.uid_text_dict.items():
+ input_text, output_text = text
+ print(
+ colored(
+ " ID: {}\n INPUT TEXT: {}\n OUTPUT TEXT: {}".format(uid, input_text, output_text),
+ "light_green",
+ )
+ )
+
+ def close(self):
+ """Saves self.uid_text_dict to disk."""
+ print(colored("\nSAVING MEMORY TO DISK", "light_green"))
+ print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
+ with open(self.path_to_dict, "wb") as file:
+ pickle.dump(self.uid_text_dict, file)
+
+ def reset_db(self):
+ """Forces immediate deletion of the DB's contents, in memory and on disk."""
+ print(colored("\nCLEARING MEMORY", "light_green"))
+ self.db_client.delete_collection("memos")
+ self.vec_db = self.db_client.create_collection("memos")
+ self.uid_text_dict = {}
+
+ def add_input_output_pair(self, input_text, output_text):
+ """Adds an input-output pair to the vector DB."""
+ self.last_memo_id += 1
+ self.vec_db.add(documents=[input_text], ids=[str(self.last_memo_id)])
+ self.uid_text_dict[str(self.last_memo_id)] = input_text, output_text
+ if self.verbosity >= 1:
+ print(
+ colored(
+ "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}".format(
+ self.last_memo_id, input_text, output_text
+ ),
+ "light_green",
+ )
+ )
+ if self.verbosity >= 3:
+ self.list_memos()
+
+ def get_nearest_memo(self, query_text):
+ """Retrieves the nearest memo to the given query text."""
+ results = self.vec_db.query(query_texts=[query_text], n_results=1)
+ uid, input_text, distance = results["ids"][0][0], results["documents"][0][0], results["distances"][0][0]
+ input_text_2, output_text = self.uid_text_dict[uid]
+ assert input_text == input_text_2
+ if self.verbosity >= 1:
+ print(
+ colored(
+ "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
+ input_text, output_text, distance
+ ),
+ "light_green",
+ )
+ )
+ return input_text, output_text, distance
+
+ def get_related_memos(self, query_text, n_results, threshold):
+ """Retrieves memos that are related to the given query text within the specified distance threshold."""
+ if n_results > len(self.uid_text_dict):
+ n_results = len(self.uid_text_dict)
+ results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
+ memos = []
+ num_results = len(results["ids"][0])
+ for i in range(num_results):
+ uid, input_text, distance = results["ids"][0][i], results["documents"][0][i], results["distances"][0][i]
+ if distance < threshold:
+ input_text_2, output_text = self.uid_text_dict[uid]
+ assert input_text == input_text_2
+ if self.verbosity >= 1:
+ print(
+ colored(
+ "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
+ input_text, output_text, distance
+ ),
+ "light_green",
+ )
+ )
+ memos.append((input_text, output_text, distance))
+ return memos
+
+ def prepopulate(self):
+ """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
+ if self.verbosity >= 1:
+ print(colored("\nPREPOPULATING MEMORY", "light_green"))
+ examples = []
+ examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"})
+ examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"})
+ examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"})
+ examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"})
+ examples.append(
+ {"text": "To create a good PPT, include enough content to make it interesting.", "label": "yes"}
+ )
+ examples.append(
+ {
+ "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
+ "label": "no",
+ }
+ )
+ examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"})
+ examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"})
+ examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"})
+ examples.append(
+ {
+ "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
+ "label": "yes",
+ }
+ )
+ for example in examples:
+ self.add_input_output_pair(example["text"], example["label"])
diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py
new file mode 100644
index 000000000..8cf88eba6
--- /dev/null
+++ b/autogen/agentchat/contrib/text_analyzer_agent.py
@@ -0,0 +1,82 @@
+from autogen import oai
+from autogen.agentchat.agent import Agent
+from autogen.agentchat.assistant_agent import ConversableAgent
+from typing import Callable, Dict, Optional, Union, List, Tuple, Any
+
+system_message = """You are an expert in text analysis.
+The user will give you TEXT to analyze.
+The user will give you analysis INSTRUCTIONS copied twice, at both the beginning and the end.
+You will follow these INSTRUCTIONS in analyzing the TEXT, then give the results of your expert analysis in the format requested."""
+
+
+class TextAnalyzerAgent(ConversableAgent):
+ """Text Analysis agent, a subclass of ConversableAgent designed to analyze text as instructed."""
+
+ def __init__(
+ self,
+ name="analyzer",
+ system_message: Optional[str] = system_message,
+ human_input_mode: Optional[str] = "NEVER",
+ llm_config: Optional[Union[Dict, bool]] = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ name (str): name of the agent.
+ system_message (str): system message for the ChatCompletion inference.
+ human_input_mode (str): This agent should NEVER prompt the human for input.
+ llm_config (dict or False): llm inference configuration.
+ Please refer to [Completion.create](/docs/reference/oai/completion#create)
+ for available options.
+ To disable llm-based auto reply, set to False.
+ teach_config (dict or None): Additional parameters used by TeachableAgent.
+ To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
+ - verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
+ - reset_db (Optional, bool): True to clear the DB before starting. Default False.
+ - path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
+ - prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
+ - recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
+ - max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
+ **kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
+ """
+ super().__init__(
+ name=name,
+ system_message=system_message,
+ human_input_mode=human_input_mode,
+ llm_config=llm_config,
+ **kwargs,
+ )
+ self.register_reply(Agent, TextAnalyzerAgent._analyze_in_reply, 1)
+
+ def _analyze_in_reply(
+ self,
+ messages: Optional[List[Dict]] = None,
+ sender: Optional[Agent] = None,
+ config: Optional[Any] = None,
+ ) -> Tuple[bool, Union[str, Dict, None]]:
+ """Analyzes the given text as instructed, and returns the analysis as a message.
+ Assumes exactly two messages containing the text to analyze and the analysis instructions.
+ See TeachableAgent.analyze for an example of how to use this method."""
+ if self.llm_config is False:
+ raise ValueError("TextAnalyzerAgent requires self.llm_config to be set in its base class.")
+ if messages is None:
+ messages = self._oai_messages[sender] # In case of a direct call.
+ assert len(messages) == 2
+
+ # Delegate to the analysis method.
+ return True, self.analyze_text(messages[0]["content"], messages[1]["content"])
+
+ def analyze_text(self, text_to_analyze, analysis_instructions):
+ """Analyzes the given text as instructed, and returns the analysis."""
+ # Assemble the message.
+ text_to_analyze = "# TEXT\n" + text_to_analyze + "\n"
+ analysis_instructions = "# INSTRUCTIONS\n" + analysis_instructions + "\n"
+ msg_text = "\n".join(
+ [analysis_instructions, text_to_analyze, analysis_instructions]
+ ) # Repeat the instructions.
+ messages = self._oai_system_message + [{"role": "user", "content": msg_text}]
+
+ # Generate and return the analysis string.
+ response = oai.ChatCompletion.create(context=None, messages=messages, **self.llm_config)
+ output_text = oai.ChatCompletion.extract_text_or_function_call(response)[0]
+ return output_text
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index c788c42db..12403d014 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -1131,3 +1131,12 @@ class ConversableAgent(Agent):
function_map: a dictionary mapping function names to functions.
"""
self._function_map.update(function_map)
+
+ def can_execute_function(self, name: str) -> bool:
+ """Whether the agent can execute the function."""
+ return name in self._function_map
+
+ @property
+ def function_map(self) -> Dict[str, Callable]:
+ """Return the function map."""
+ return self._function_map
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index d2f53002b..9ed2ff774 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -10,12 +10,23 @@ logger = logging.getLogger(__name__)
@dataclass
class GroupChat:
- """A group chat class that contains a list of agents and the maximum number of rounds."""
+ """A group chat class that contains the following data fields:
+ - agents: a list of participating agents.
+ - messages: a list of messages in the group chat.
+ - max_round: the maximum number of rounds.
+ - admin_name: the name of the admin agent if there is one. Default is "Admin".
+ KeyBoardInterrupt will make the admin agent take over.
+ - func_call_filter: whether to enforce function call filter. Default is True.
+ When set to True and when a message is a function call suggestion,
+ the next speaker will be chosen from an agent which contains the corresponding function name
+ in its `function_map`.
+ """
agents: List[Agent]
messages: List[Dict]
max_round: int = 10
- admin_name: str = "Admin" # the name of the admin agent
+ admin_name: str = "Admin"
+ func_call_filter: bool = True
@property
def agent_names(self) -> List[str]:
@@ -30,45 +41,69 @@ class GroupChat:
"""Find the next speaker based on the message."""
return self.agents[self.agent_names.index(name)]
- def next_agent(self, agent: Agent) -> Agent:
+ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
"""Return the next agent in the list."""
- return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)]
+ if agents == self.agents:
+ return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
+ else:
+ offset = self.agent_names.index(agent.name) + 1
+ for i in range(len(self.agents)):
+ if self.agents[(offset + i) % len(self.agents)] in agents:
+ return self.agents[(offset + i) % len(self.agents)]
- def select_speaker_msg(self):
+ def select_speaker_msg(self, agents: List[Agent]):
"""Return the message for selecting the next speaker."""
return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Read the following conversation.
-Then select the next role from {self.agent_names} to play. Only return the role."""
+Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""
def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
"""Select the next speaker."""
- selector.update_system_message(self.select_speaker_msg())
-
- # Warn if GroupChat is underpopulated, without established changing behavior
- n_agents = len(self.agent_names)
- if n_agents < 3:
- logger.warning(
- f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
- )
-
+ if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
+ # find agents with the right function_map which contains the function name
+ agents = [
+ agent for agent in self.agents if agent.can_execute_function(self.messages[-1]["function_call"]["name"])
+ ]
+ if len(agents) == 1:
+ # only one agent can execute the function
+ return agents[0]
+ elif not agents:
+ # find all the agents with function_map
+ agents = [agent for agent in self.agents if agent.function_map]
+ if len(agents) == 1:
+ return agents[0]
+ elif not agents:
+ raise ValueError(
+ f"No agent can execute the function {self.messages[-1]['name']}. "
+ "Please check the function_map of the agents."
+ )
+ else:
+ agents = self.agents
+ # Warn if GroupChat is underpopulated
+ n_agents = len(agents)
+ if n_agents < 3:
+ logger.warning(
+ f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
+ )
+ selector.update_system_message(self.select_speaker_msg(agents))
final, name = selector.generate_oai_reply(
self.messages
+ [
{
"role": "system",
- "content": f"Read the above conversation. Then select the next role from {self.agent_names} to play. Only return the role.",
+ "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
}
]
)
if not final:
# i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id
- return self.next_agent(last_speaker)
+ return self.next_agent(last_speaker, agents)
try:
return self.agent_by_name(name)
except ValueError:
- return self.next_agent(last_speaker)
+ return self.next_agent(last_speaker, agents)
def _participant_roles(self):
return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py
index 54739ec5d..a720ccc24 100644
--- a/autogen/oai/completion.py
+++ b/autogen/oai/completion.py
@@ -51,6 +51,7 @@ class Completion(openai_Completion):
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-35-turbo",
+ "gpt-35-turbo-16k",
"gpt-4",
"gpt-4-32k",
"gpt-4-32k-0314", # deprecate in Sep
@@ -69,11 +70,14 @@ class Completion(openai_Completion):
"text-davinci-002": 0.02,
"text-davinci-003": 0.02,
"gpt-3.5-turbo": (0.0015, 0.002),
+ "gpt-3.5-turbo-instruct": (0.0015, 0.002),
"gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep
"gpt-3.5-turbo-0613": (0.0015, 0.002),
"gpt-3.5-turbo-16k": (0.003, 0.004),
"gpt-3.5-turbo-16k-0613": (0.003, 0.004),
- "gpt-35-turbo": 0.002,
+ "gpt-35-turbo": (0.0015, 0.002),
+ "gpt-35-turbo-16k": (0.003, 0.004),
+ "gpt-35-turbo-instruct": (0.0015, 0.002),
"gpt-4": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
"gpt-4-0314": (0.03, 0.06), # deprecate in Sep
diff --git a/autogen/version.py b/autogen/version.py
index 0c5c30071..3cb7d95ef 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.1.11"
+__version__ = "0.1.13"
diff --git a/notebook/agentchat_groupchat_RAG.ipynb b/notebook/agentchat_groupchat_RAG.ipynb
new file mode 100644
index 000000000..fd12cbe8c
--- /dev/null
+++ b/notebook/agentchat_groupchat_RAG.ipynb
@@ -0,0 +1,1501 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Auto Generated Agent Chat: Group Chat with Retrieval Augmented Generation\n",
+ "\n",
+ "AutoGen supports conversable agents powered by LLMs, tools or humans, performing tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n",
+ "Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
+ "\n",
+ "## Requirements\n",
+ "\n",
+ "AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n",
+ "```bash\n",
+ "pip install pyautogen\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-stderr\n",
+ "# %pip install pyautogen[retrievechat]~=0.1.11"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Set your API Endpoint\n",
+ "\n",
+ "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LLM models: ['gpt-35-turbo', 'gpt-35-turbo-0613']\n"
+ ]
+ }
+ ],
+ "source": [
+ "import autogen\n",
+ "\n",
+ "config_list = autogen.config_list_from_json(\n",
+ " \"OAI_CONFIG_LIST\",\n",
+ " file_location=\".\",\n",
+ " filter_dict={\n",
+ " \"model\": [\"gpt-3.5-turbo\", \"gpt-35-turbo\", \"gpt-35-turbo-0613\", \"gpt-4\", \"gpt4\", \"gpt-4-32k\"],\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "print(\"LLM models: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well).\n",
+ "\n",
+ "The config list looks like the following:\n",
+ "```python\n",
+ "config_list = [\n",
+ " {\n",
+ " \"model\": \"gpt-4\",\n",
+ " \"api_key\": \"\",\n",
+ " }, # OpenAI API endpoint for gpt-4\n",
+ " {\n",
+ " \"engine\": \"gpt-35-turbo-0631\", \n",
+ " \"model\": \"gpt-35-turbo-0631\", # 0631 or newer is needed to use functions\n",
+ " \"api_base\": \"\", \n",
+ " \"api_type\": \"azure\", \n",
+ " \"api_version\": \"2023-07-01-preview\", # 2023-07-01-preview or newer is needed to use functions\n",
+ " \"api_key\": \"\"\n",
+ " }\n",
+ "]\n",
+ "```\n",
+ "\n",
+ "If you open this notebook in colab, you can upload your files by clicking the file icon on the left panel and then choose \"upload file\" icon.\n",
+ "\n",
+ "You can set the value of config_list in other ways you prefer, e.g., loading from a YAML file."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Construct Agents"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n",
+ "from autogen import AssistantAgent\n",
+ "import chromadb\n",
+ "\n",
+ "llm_config = {\n",
+ " \"request_timeout\": 60,\n",
+ " \"seed\": 42,\n",
+ " \"config_list\": config_list,\n",
+ " \"temperature\": 0,\n",
+ "}\n",
+ "\n",
+ "autogen.ChatCompletion.start_logging()\n",
+ "termination_msg = lambda x: isinstance(x, dict) and \"TERMINATE\" == str(x.get(\"content\", \"\"))[-9:].upper()\n",
+ "\n",
+ "boss = autogen.UserProxyAgent(\n",
+ " name=\"Boss\",\n",
+ " is_termination_msg=termination_msg,\n",
+ " human_input_mode=\"TERMINATE\",\n",
+ " system_message=\"The boss who ask questions and give tasks.\",\n",
+ " code_execution_config=False, # we don't want to execute code in this case.\n",
+ ")\n",
+ "\n",
+ "boss_aid = RetrieveUserProxyAgent(\n",
+ " name=\"Boss_Assistant\",\n",
+ " is_termination_msg=termination_msg,\n",
+ " system_message=\"Assistant who has extra content retrieval power for solving difficult problems.\",\n",
+ " human_input_mode=\"TERMINATE\",\n",
+ " max_consecutive_auto_reply=3,\n",
+ " retrieve_config={\n",
+ " \"task\": \"code\",\n",
+ " \"docs_path\": \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n",
+ " \"chunk_token_size\": 1000,\n",
+ " \"model\": config_list[0][\"model\"],\n",
+ " \"client\": chromadb.PersistentClient(path=\"/tmp/chromadb\"),\n",
+ " \"collection_name\": \"groupchat\",\n",
+ " \"get_or_create\": True,\n",
+ " },\n",
+ " code_execution_config=False, # we don't want to execute code in this case.\n",
+ ")\n",
+ "\n",
+ "coder = AssistantAgent(\n",
+ " name=\"Senior_Python_Engineer\",\n",
+ " is_termination_msg=termination_msg,\n",
+ " system_message=\"You are a senior python engineer. Reply `TERMINATE` in the end when everything is done.\",\n",
+ " llm_config=llm_config,\n",
+ ")\n",
+ "\n",
+ "pm = autogen.AssistantAgent(\n",
+ " name=\"Product_Manager\",\n",
+ " is_termination_msg=termination_msg,\n",
+ " system_message=\"You are a product manager. Reply `TERMINATE` in the end when everything is done.\",\n",
+ " llm_config=llm_config,\n",
+ ")\n",
+ "\n",
+ "reviewer = autogen.AssistantAgent(\n",
+ " name=\"Code_Reviewer\",\n",
+ " is_termination_msg=termination_msg,\n",
+ " system_message=\"You are a code reviewer. Reply `TERMINATE` in the end when everything is done.\",\n",
+ " llm_config=llm_config,\n",
+ ")\n",
+ "\n",
+ "PROBLEM = \"How to use spark for parallel training in FLAML? Give me sample code.\"\n",
+ "\n",
+ "def _reset_agents():\n",
+ " boss.reset()\n",
+ " boss_aid.reset()\n",
+ " coder.reset()\n",
+ " pm.reset()\n",
+ " reviewer.reset()\n",
+ "\n",
+ "def rag_chat():\n",
+ " _reset_agents()\n",
+ " groupchat = autogen.GroupChat(\n",
+ " agents=[boss_aid, coder, pm, reviewer], messages=[], max_round=12\n",
+ " )\n",
+ " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n",
+ "\n",
+ " # Start chatting with boss_aid as this is the user proxy agent.\n",
+ " boss_aid.initiate_chat(\n",
+ " manager,\n",
+ " problem=PROBLEM,\n",
+ " n_results=3,\n",
+ " )\n",
+ "\n",
+ "def norag_chat():\n",
+ " _reset_agents()\n",
+ " groupchat = autogen.GroupChat(\n",
+ " agents=[boss, coder, pm, reviewer], messages=[], max_round=12\n",
+ " )\n",
+ " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n",
+ "\n",
+ " # Start chatting with boss as this is the user proxy agent.\n",
+ " boss.initiate_chat(\n",
+ " manager,\n",
+ " message=PROBLEM,\n",
+ " )\n",
+ "\n",
+ "def call_rag_chat():\n",
+ " _reset_agents()\n",
+ " # In this case, we will have multiple user proxy agents and we don't initiate the chat\n",
+ " # with RAG user proxy agent.\n",
+ " # In order to use RAG user proxy agent, we need to wrap RAG agents in a function and call\n",
+ " # it from other agents.\n",
+ " def retrieve_content(message, n_results=3):\n",
+ " boss_aid.n_results = n_results # Set the number of results to be retrieved.\n",
+ " # Check if we need to update the context.\n",
+ " update_context_case1, update_context_case2 = boss_aid._check_update_context(message)\n",
+ " if (update_context_case1 or update_context_case2) and boss_aid.update_context:\n",
+ " boss_aid.problem = message if not hasattr(boss_aid, \"problem\") else boss_aid.problem\n",
+ " _, ret_msg = boss_aid._generate_retrieve_user_reply(message)\n",
+ " else:\n",
+ " ret_msg = boss_aid.generate_init_message(message, n_results=n_results)\n",
+ " return ret_msg if ret_msg else message\n",
+ " \n",
+ " boss_aid.human_input_mode = \"NEVER\" # Disable human input for boss_aid since it only retrieves content.\n",
+ " \n",
+ " llm_config = {\n",
+ " \"functions\": [\n",
+ " {\n",
+ " \"name\": \"retrieve_content\",\n",
+ " \"description\": \"retrieve content for code generation and question answering.\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"message\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"Refined message which keeps the original meaning and can be used to retrieve content for code generation and question answering.\",\n",
+ " }\n",
+ " },\n",
+ " \"required\": [\"message\"],\n",
+ " },\n",
+ " },\n",
+ " ],\n",
+ " \"config_list\": config_list,\n",
+ " \"request_timeout\": 60,\n",
+ " \"seed\": 42,\n",
+ " }\n",
+ "\n",
+ " for agent in [coder, pm, reviewer]:\n",
+ " # update llm_config for assistant agents.\n",
+ " agent.llm_config.update(llm_config)\n",
+ "\n",
+ " for agent in [boss, coder, pm, reviewer]:\n",
+ " # register functions for all agents.\n",
+ " agent.register_function(\n",
+ " function_map={\n",
+ " \"retrieve_content\": retrieve_content,\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " groupchat = autogen.GroupChat(\n",
+ " agents=[boss, coder, pm, reviewer], messages=[], max_round=12\n",
+ " )\n",
+ " manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)\n",
+ "\n",
+ " # Start chatting with boss as this is the user proxy agent.\n",
+ " boss.initiate_chat(\n",
+ " manager,\n",
+ " message=PROBLEM,\n",
+ " )"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Start Chat\n",
+ "\n",
+ "### UserProxyAgent doesn't get the correct code\n",
+ "[FLAML](https://github.com/microsoft/FLAML) was open sourced in 2020, so ChatGPT is familiar with it. However, Spark-related APIs were added in 2022, so they were not in ChatGPT's training data. As a result, we end up with invalid code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "To use Spark for parallel training in FLAML, you can use the `SparkTrials` class provided by FLAML. Here is a sample code:\n",
+ "\n",
+ "```python\n",
+ "from flaml import AutoML\n",
+ "from flaml.data import load_credit\n",
+ "from flaml.model import SparkTrials\n",
+ "\n",
+ "# Load data\n",
+ "X_train, y_train, X_test, y_test = load_credit()\n",
+ "\n",
+ "# Define the search space\n",
+ "search_space = {\n",
+ " \"n_estimators\": {\"domain\": range(10, 100)},\n",
+ " \"max_depth\": {\"domain\": range(6, 10)},\n",
+ " \"learning_rate\": {\"domain\": (0.01, 0.1, 1)},\n",
+ "}\n",
+ "\n",
+ "# Create an AutoML instance with SparkTrials\n",
+ "automl = AutoML(\n",
+ " search_space=search_space,\n",
+ " task=\"classification\",\n",
+ " n_jobs=1,\n",
+ " ensemble_size=0,\n",
+ " max_time=60,\n",
+ " trials=SparkTrials(parallelism=2),\n",
+ ")\n",
+ "\n",
+ "# Train the model\n",
+ "automl.fit(X_train=X_train, y_train=y_train)\n",
+ "\n",
+ "# Evaluate the model\n",
+ "print(\"Best model:\", automl.best_model)\n",
+ "print(\"Best hyperparameters:\", automl.best_config)\n",
+ "print(\"Test accuracy:\", automl.score(X_test, y_test))\n",
+ "\n",
+ "# Terminate\n",
+ "TERMINATE\n",
+ "```\n",
+ "\n",
+ "In this code, we first load the credit dataset. Then, we define the search space for the hyperparameters. We create an `AutoML` instance with `SparkTrials` as the `trials` parameter. We set the `parallelism` parameter to 2 to use 2 Spark workers for parallel training. Finally, we fit the model and evaluate it.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Great! This code looks good to me.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Thank you! Let me know if you have any other questions.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "norag_chat()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### RetrieveUserProxyAgent get the correct code\n",
+ "Since RetrieveUserProxyAgent can perform retrieval-augmented generation based on the given documentation file, ChatGPT can generate the correct code for us!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Trying to create collection.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:autogen.retrieve_utils:Found 2 chunks.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "doc_ids: [['doc_0', 'doc_1', 'doc_4']]\n",
+ "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n",
+ "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n",
+ "\u001b[32mAdding doc_id doc_4 to context.\u001b[0m\n",
+ "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n",
+ "\n",
+ "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n",
+ "context provided by the user.\n",
+ "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n",
+ "For code generation, you must obey the following rules:\n",
+ "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n",
+ "Rule 2. You must follow the formats below to write your code:\n",
+ "```language\n",
+ "# your code\n",
+ "```\n",
+ "\n",
+ "User's question is: How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "Context is: # Integrate - Spark\n",
+ "\n",
+ "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n",
+ "- Use Spark ML estimators for AutoML.\n",
+ "- Use Spark to run training in parallel spark jobs.\n",
+ "\n",
+ "## Spark ML Estimators\n",
+ "\n",
+ "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
+ "\n",
+ "### Data\n",
+ "\n",
+ "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n",
+ "\n",
+ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n",
+ "\n",
+ "This function also accepts optional arguments `index_col` and `default_index_type`.\n",
+ "- `index_col` is the column name to use as the index, default is None.\n",
+ "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n",
+ "\n",
+ "Here is an example code snippet for Spark Data:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "from flaml.automl.spark.utils import to_pandas_on_spark\n",
+ "# Creating a dictionary\n",
+ "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
+ " \"Age_Years\": [20, 15, 10, 7, 25],\n",
+ " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n",
+ "\n",
+ "# Creating a pandas DataFrame\n",
+ "dataframe = pd.DataFrame(data)\n",
+ "label = \"Price\"\n",
+ "\n",
+ "# Convert to pandas-on-spark dataframe\n",
+ "psdf = to_pandas_on_spark(dataframe)\n",
+ "```\n",
+ "\n",
+ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n",
+ "\n",
+ "Here is an example of how to use it:\n",
+ "```python\n",
+ "from pyspark.ml.feature import VectorAssembler\n",
+ "columns = psdf.columns\n",
+ "feature_cols = [col for col in columns if col != label]\n",
+ "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
+ "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
+ "```\n",
+ "\n",
+ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
+ "\n",
+ "### Estimators\n",
+ "#### Model List\n",
+ "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n",
+ "\n",
+ "#### Usage\n",
+ "First, prepare your data in the required format as described in the previous section.\n",
+ "\n",
+ "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n",
+ "\n",
+ "Here is an example code snippet using SparkML models in AutoML:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# prepare your data in pandas-on-spark format as we previously mentioned\n",
+ "\n",
+ "automl = flaml.AutoML()\n",
+ "settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n",
+ " \"task\": \"regression\",\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=psdf,\n",
+ " label=label,\n",
+ " **settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n",
+ "\n",
+ "## Parallel Spark Jobs\n",
+ "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n",
+ "\n",
+ "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
+ "\n",
+ "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n",
+ "\n",
+ "\n",
+ "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n",
+ "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n",
+ "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n",
+ "\n",
+ "An example code snippet for using parallel Spark jobs:\n",
+ "```python\n",
+ "import flaml\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"n_concurrent_trials\": 2,\n",
+ " \"use_spark\": True,\n",
+ " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=dataframe,\n",
+ " label=label,\n",
+ " **automl_settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# for flaml.tune\n",
+ "with mlflow.start_run(run_name=f\"spark_auto_trials_1686631558\"):\n",
+ " analysis = flaml.tune.run(\n",
+ " func_to_tune,\n",
+ " params,\n",
+ " metric=\"r2\",\n",
+ " mode=\"max\",\n",
+ " mlflow_exp_name=\"test_doc\",\n",
+ " use_spark=True,\n",
+ " )\n",
+ "\n",
+ "# for flaml.automl\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"use_spark\": True,\n",
+ " \"mlflow_exp_name\": \"test_doc\",\n",
+ " \"estimator_list\": [\n",
+ " \"lgbm\",\n",
+ " \"rf\",\n",
+ " \"xgboost\",\n",
+ " \"extra_tree\",\n",
+ " \"xgb_limitdepth\",\n",
+ " ], # catboost does not yet support mlflow autologging\n",
+ "}\n",
+ "with mlflow.start_run(run_name=f\"automl_spark_trials_1686631579\"):\n",
+ " automl_experiment.fit(X_train=train_x, y_train=train_y, **automl_settings)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Results\n",
+ "*Tune Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "*AutoML Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Differences Between Auto and Manual Logging\n",
+ "Autologging is managed by MLFlow, while manual logging is maintained by FLAML.\n",
+ "\n",
+ "\n",
+ "#### Details of Manual Logging\n",
+ "FLAML logs general artifacts for AutoML tasks. Specifically, we log these artifacts:\n",
+ "\n",
+ "**`flaml.tune`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the best metric and the best configuration for the entire tuning process.\n",
+ "- For each trial, we create a child run to log the metric specific to the tune function and the configuration for that trial.\n",
+ "\n",
+ "**`flaml.automl`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the results of the experiment. This includes:\n",
+ " - The configuration of this model.\n",
+ " - The `best_validation_loss` produced by this model.\n",
+ " - The `best_iteration` to identify the point at which this model was found.\n",
+ "- For each state (a specific learner with different hyperparameters), we record the best trial for this model. This includes:\n",
+ " - The configuration of the best trial.\n",
+ " - The `validation_loss` the best trial produces.\n",
+ " - The `iter_count` to identify how many trials we have conducted for this state.\n",
+ " - The `pred_time`, which is the time cost of predicting test data for this model.\n",
+ " - The `wall_clock_time`, which is the time cost of this state.\n",
+ " - The `sample_size` to show how much data we sampled in this state.\n",
+ "Note that we also added these information to autolog AutoML run.\n",
+ "\n",
+ "\n",
+ "#### Details of Autologging\n",
+ "Autolog artifacts typically include model parameters, model files, and runtime metrics like the following:\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Artifacts can differ among various machine learning libraries. More detailed information can be found [here](https://mlflow.org/docs/latest/tracking.html#automatic-logging).\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "## Plot Experiment Result\n",
+ "The `flaml.visualization` module provides utility functions for plotting the optimization process using [plotly](https://plotly.com/python/). Leveraging `plotly`, users can interactively explore experiment results. To use these plotting functions, simply provide your optimized `flaml.AutoML` or `flaml.tune.tune.ExperimentAnalysis` object as input. Optional parameters can be added using keyword arguments.\n",
+ "\n",
+ "Avaliable plotting functions:\n",
+ "- `plot_optimization_history`: Plot optimization history of all trials in the experiment.\n",
+ "- `plot_feature_importance`: Plot importance for each feature in the dataset.\n",
+ "- `plot_parallel_coordinate`: Plot the high-dimensional parameter relationships in the experiment.\n",
+ "- `plot_contour`: Plot the parameter relationship as contour plot in the experiment.\n",
+ "- `plot_edf`: Plot the objective value EDF (empirical distribution function) of the experiment.\n",
+ "- `plot_timeline`: Plot the timeline of the experiment.\n",
+ "- `plot_slice`: Plot the parameter relationship as slice plot in a study.\n",
+ "\n",
+ "### Figure Examples\n",
+ "\n",
+ "\n",
+ "Check out our example [notebook](../../notebook/trident/automl_plot.ipynb) for a preview of all interactive plots.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n",
+ "\u001b[32mAdding doc_id doc_4 to context.\u001b[0m\n",
+ "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n",
+ "\n",
+ "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n",
+ "context provided by the user.\n",
+ "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n",
+ "For code generation, you must obey the following rules:\n",
+ "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n",
+ "Rule 2. You must follow the formats below to write your code:\n",
+ "```language\n",
+ "# your code\n",
+ "```\n",
+ "\n",
+ "User's question is: How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "Context is: # Integrate - Spark\n",
+ "\n",
+ "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n",
+ "- Use Spark ML estimators for AutoML.\n",
+ "- Use Spark to run training in parallel spark jobs.\n",
+ "\n",
+ "## Spark ML Estimators\n",
+ "\n",
+ "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
+ "\n",
+ "### Data\n",
+ "\n",
+ "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n",
+ "\n",
+ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n",
+ "\n",
+ "This function also accepts optional arguments `index_col` and `default_index_type`.\n",
+ "- `index_col` is the column name to use as the index, default is None.\n",
+ "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n",
+ "\n",
+ "Here is an example code snippet for Spark Data:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "from flaml.automl.spark.utils import to_pandas_on_spark\n",
+ "# Creating a dictionary\n",
+ "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
+ " \"Age_Years\": [20, 15, 10, 7, 25],\n",
+ " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n",
+ "\n",
+ "# Creating a pandas DataFrame\n",
+ "dataframe = pd.DataFrame(data)\n",
+ "label = \"Price\"\n",
+ "\n",
+ "# Convert to pandas-on-spark dataframe\n",
+ "psdf = to_pandas_on_spark(dataframe)\n",
+ "```\n",
+ "\n",
+ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n",
+ "\n",
+ "Here is an example of how to use it:\n",
+ "```python\n",
+ "from pyspark.ml.feature import VectorAssembler\n",
+ "columns = psdf.columns\n",
+ "feature_cols = [col for col in columns if col != label]\n",
+ "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
+ "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
+ "```\n",
+ "\n",
+ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
+ "\n",
+ "### Estimators\n",
+ "#### Model List\n",
+ "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n",
+ "\n",
+ "#### Usage\n",
+ "First, prepare your data in the required format as described in the previous section.\n",
+ "\n",
+ "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n",
+ "\n",
+ "Here is an example code snippet using SparkML models in AutoML:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# prepare your data in pandas-on-spark format as we previously mentioned\n",
+ "\n",
+ "automl = flaml.AutoML()\n",
+ "settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n",
+ " \"task\": \"regression\",\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=psdf,\n",
+ " label=label,\n",
+ " **settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n",
+ "\n",
+ "## Parallel Spark Jobs\n",
+ "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n",
+ "\n",
+ "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
+ "\n",
+ "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n",
+ "\n",
+ "\n",
+ "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n",
+ "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n",
+ "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n",
+ "\n",
+ "An example code snippet for using parallel Spark jobs:\n",
+ "```python\n",
+ "import flaml\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"n_concurrent_trials\": 2,\n",
+ " \"use_spark\": True,\n",
+ " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=dataframe,\n",
+ " label=label,\n",
+ " **automl_settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# for flaml.tune\n",
+ "with mlflow.start_run(run_name=f\"spark_auto_trials_1686631558\"):\n",
+ " analysis = flaml.tune.run(\n",
+ " func_to_tune,\n",
+ " params,\n",
+ " metric=\"r2\",\n",
+ " mode=\"max\",\n",
+ " mlflow_exp_name=\"test_doc\",\n",
+ " use_spark=True,\n",
+ " )\n",
+ "\n",
+ "# for flaml.automl\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"use_spark\": True,\n",
+ " \"mlflow_exp_name\": \"test_doc\",\n",
+ " \"estimator_list\": [\n",
+ " \"lgbm\",\n",
+ " \"rf\",\n",
+ " \"xgboost\",\n",
+ " \"extra_tree\",\n",
+ " \"xgb_limitdepth\",\n",
+ " ], # catboost does not yet support mlflow autologging\n",
+ "}\n",
+ "with mlflow.start_run(run_name=f\"automl_spark_trials_1686631579\"):\n",
+ " automl_experiment.fit(X_train=train_x, y_train=train_y, **automl_settings)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Results\n",
+ "*Tune Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "*AutoML Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Differences Between Auto and Manual Logging\n",
+ "Autologging is managed by MLFlow, while manual logging is maintained by FLAML.\n",
+ "\n",
+ "\n",
+ "#### Details of Manual Logging\n",
+ "FLAML logs general artifacts for AutoML tasks. Specifically, we log these artifacts:\n",
+ "\n",
+ "**`flaml.tune`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the best metric and the best configuration for the entire tuning process.\n",
+ "- For each trial, we create a child run to log the metric specific to the tune function and the configuration for that trial.\n",
+ "\n",
+ "**`flaml.automl`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the results of the experiment. This includes:\n",
+ " - The configuration of this model.\n",
+ " - The `best_validation_loss` produced by this model.\n",
+ " - The `best_iteration` to identify the point at which this model was found.\n",
+ "- For each state (a specific learner with different hyperparameters), we record the best trial for this model. This includes:\n",
+ " - The configuration of the best trial.\n",
+ " - The `validation_loss` the best trial produces.\n",
+ " - The `iter_count` to identify how many trials we have conducted for this state.\n",
+ " - The `pred_time`, which is the time cost of predicting test data for this model.\n",
+ " - The `wall_clock_time`, which is the time cost of this state.\n",
+ " - The `sample_size` to show how much data we sampled in this state.\n",
+ "Note that we also added these information to autolog AutoML run.\n",
+ "\n",
+ "\n",
+ "#### Details of Autologging\n",
+ "Autolog artifacts typically include model parameters, model files, and runtime metrics like the following:\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Artifacts can differ among various machine learning libraries. More detailed information can be found [here](https://mlflow.org/docs/latest/tracking.html#automatic-logging).\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "## Plot Experiment Result\n",
+ "The `flaml.visualization` module provides utility functions for plotting the optimization process using [plotly](https://plotly.com/python/). Leveraging `plotly`, users can interactively explore experiment results. To use these plotting functions, simply provide your optimized `flaml.AutoML` or `flaml.tune.tune.ExperimentAnalysis` object as input. Optional parameters can be added using keyword arguments.\n",
+ "\n",
+ "Avaliable plotting functions:\n",
+ "- `plot_optimization_history`: Plot optimization history of all trials in the experiment.\n",
+ "- `plot_feature_importance`: Plot importance for each feature in the dataset.\n",
+ "- `plot_parallel_coordinate`: Plot the high-dimensional parameter relationships in the experiment.\n",
+ "- `plot_contour`: Plot the parameter relationship as contour plot in the experiment.\n",
+ "- `plot_edf`: Plot the objective value EDF (empirical distribution function) of the experiment.\n",
+ "- `plot_timeline`: Plot the timeline of the experiment.\n",
+ "- `plot_slice`: Plot the parameter relationship as slice plot in a study.\n",
+ "\n",
+ "### Figure Examples\n",
+ "\n",
+ "\n",
+ "Check out our example [notebook](../../notebook/trident/automl_plot.ipynb) for a preview of all interactive plots.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "To use Spark for parallel training in FLAML, you can activate Spark as the parallel backend during parallel tuning in both AutoML and Hyperparameter Tuning, by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using `joblib-spark`. \n",
+ "\n",
+ "Here is an example code snippet for using parallel Spark jobs:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"use_spark\": True,\n",
+ " \"estimator_list\": [\n",
+ " \"lgbm\",\n",
+ " \"rf\",\n",
+ " \"xgboost\",\n",
+ " \"extra_tree\",\n",
+ " \"xgb_limitdepth\",\n",
+ " ],\n",
+ "}\n",
+ "automl_experiment.fit(X_train=train_x, y_train=train_y, **automl_settings)\n",
+ "```\n",
+ "\n",
+ "Note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
+ "\n",
+ "You can also use Spark ML estimators for AutoML. FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
+ "\n",
+ "Here is an example code snippet for Spark Data:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "from flaml.automl.spark.utils import to_pandas_on_spark\n",
+ "# Creating a dictionary\n",
+ "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
+ " \"Age_Years\": [20, 15, 10, 7, 25],\n",
+ " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n",
+ "\n",
+ "# Creating a pandas DataFrame\n",
+ "dataframe = pd.DataFrame(data)\n",
+ "label = \"Price\"\n",
+ "\n",
+ "# Convert to pandas-on-spark dataframe\n",
+ "psdf = to_pandas_on_spark(dataframe)\n",
+ "```\n",
+ "\n",
+ "To use Spark ML models you need to format your data appropriately. Specifically, use `VectorAssembler` to merge all feature columns into a single vector column.\n",
+ "\n",
+ "Here is an example of how to use it:\n",
+ "```python\n",
+ "from pyspark.ml.feature import VectorAssembler\n",
+ "columns = psdf.columns\n",
+ "feature_cols = [col for col in columns if col != label]\n",
+ "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
+ "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
+ "```\n",
+ "\n",
+ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
+ "\n",
+ "You can also plot the optimization process using `plotly` by providing your optimized `flaml.AutoML` or `flaml.tune.tune.ExperimentAnalysis` object as input. Optional parameters can be added using keyword arguments. Available plotting functions include `plot_optimization_history`, `plot_feature_importance`, `plot_parallel_coordinate`, `plot_contour`, `plot_edf`, `plot_timeline`, and `plot_slice`.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Is there anything else you need help with?\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n",
+ "\n",
+ "No, that's all. Thank you for your help!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "You're welcome! Don't hesitate to ask if you have any more questions in the future. Have a great day!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Have a great day too!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Thank you!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "You're welcome!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss_Assistant\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "rag_chat()\n",
+ "# type exit to terminate the chat"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Call RetrieveUserProxyAgent while init chat with another user proxy agent\n",
+ "Sometimes, there might be a need to use RetrieveUserProxyAgent in group chat without initializing the chat with it. In such scenarios, it becomes essential to create a function that wraps the RAG agents and allows them to be called from other agents."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "How to use spark for parallel training in FLAML? Give me sample code.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\u001b[32m***** Suggested function Call: retrieve_content *****\u001b[0m\n",
+ "Arguments: \n",
+ "{\n",
+ " \"message\": \"How to use spark for parallel training in FLAML?\"\n",
+ "}\n",
+ "\u001b[32m*****************************************************\u001b[0m\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[35m\n",
+ ">>>>>>>> EXECUTING FUNCTION retrieve_content...\u001b[0m\n",
+ "doc_ids: [['doc_0', 'doc_1', 'doc_4']]\n",
+ "\u001b[32mAdding doc_id doc_0 to context.\u001b[0m\n",
+ "\u001b[32mAdding doc_id doc_1 to context.\u001b[0m\n",
+ "\u001b[32mAdding doc_id doc_4 to context.\u001b[0m\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\u001b[32m***** Response from calling function \"retrieve_content\" *****\u001b[0m\n",
+ "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n",
+ "context provided by the user.\n",
+ "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n",
+ "For code generation, you must obey the following rules:\n",
+ "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n",
+ "Rule 2. You must follow the formats below to write your code:\n",
+ "```language\n",
+ "# your code\n",
+ "```\n",
+ "\n",
+ "User's question is: How to use spark for parallel training in FLAML?\n",
+ "\n",
+ "Context is: # Integrate - Spark\n",
+ "\n",
+ "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n",
+ "- Use Spark ML estimators for AutoML.\n",
+ "- Use Spark to run training in parallel spark jobs.\n",
+ "\n",
+ "## Spark ML Estimators\n",
+ "\n",
+ "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
+ "\n",
+ "### Data\n",
+ "\n",
+ "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n",
+ "\n",
+ "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n",
+ "\n",
+ "This function also accepts optional arguments `index_col` and `default_index_type`.\n",
+ "- `index_col` is the column name to use as the index, default is None.\n",
+ "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n",
+ "\n",
+ "Here is an example code snippet for Spark Data:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "from flaml.automl.spark.utils import to_pandas_on_spark\n",
+ "# Creating a dictionary\n",
+ "data = {\"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
+ " \"Age_Years\": [20, 15, 10, 7, 25],\n",
+ " \"Price\": [100000, 200000, 300000, 240000, 120000]}\n",
+ "\n",
+ "# Creating a pandas DataFrame\n",
+ "dataframe = pd.DataFrame(data)\n",
+ "label = \"Price\"\n",
+ "\n",
+ "# Convert to pandas-on-spark dataframe\n",
+ "psdf = to_pandas_on_spark(dataframe)\n",
+ "```\n",
+ "\n",
+ "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n",
+ "\n",
+ "Here is an example of how to use it:\n",
+ "```python\n",
+ "from pyspark.ml.feature import VectorAssembler\n",
+ "columns = psdf.columns\n",
+ "feature_cols = [col for col in columns if col != label]\n",
+ "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
+ "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
+ "```\n",
+ "\n",
+ "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
+ "\n",
+ "### Estimators\n",
+ "#### Model List\n",
+ "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n",
+ "\n",
+ "#### Usage\n",
+ "First, prepare your data in the required format as described in the previous section.\n",
+ "\n",
+ "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n",
+ "\n",
+ "Here is an example code snippet using SparkML models in AutoML:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# prepare your data in pandas-on-spark format as we previously mentioned\n",
+ "\n",
+ "automl = flaml.AutoML()\n",
+ "settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n",
+ " \"task\": \"regression\",\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=psdf,\n",
+ " label=label,\n",
+ " **settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n",
+ "\n",
+ "## Parallel Spark Jobs\n",
+ "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n",
+ "\n",
+ "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
+ "\n",
+ "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n",
+ "\n",
+ "\n",
+ "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n",
+ "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n",
+ "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n",
+ "\n",
+ "An example code snippet for using parallel Spark jobs:\n",
+ "```python\n",
+ "import flaml\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"n_concurrent_trials\": 2,\n",
+ " \"use_spark\": True,\n",
+ " \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=dataframe,\n",
+ " label=label,\n",
+ " **automl_settings,\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "\n",
+ "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n",
+ "\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "# for flaml.tune\n",
+ "with mlflow.start_run(run_name=f\"spark_auto_trials_1686631558\"):\n",
+ " analysis = flaml.tune.run(\n",
+ " func_to_tune,\n",
+ " params,\n",
+ " metric=\"r2\",\n",
+ " mode=\"max\",\n",
+ " mlflow_exp_name=\"test_doc\",\n",
+ " use_spark=True,\n",
+ " )\n",
+ "\n",
+ "# for flaml.automl\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"use_spark\": True,\n",
+ " \"mlflow_exp_name\": \"test_doc\",\n",
+ " \"estimator_list\": [\n",
+ " \"lgbm\",\n",
+ " \"rf\",\n",
+ " \"xgboost\",\n",
+ " \"extra_tree\",\n",
+ " \"xgb_limitdepth\",\n",
+ " ], # catboost does not yet support mlflow autologging\n",
+ "}\n",
+ "with mlflow.start_run(run_name=f\"automl_spark_trials_1686631579\"):\n",
+ " automl_experiment.fit(X_train=train_x, y_train=train_y, **automl_settings)\n",
+ "```\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Results\n",
+ "*Tune Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "*AutoML Autolog Trials on MLFlow UI*\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "### Differences Between Auto and Manual Logging\n",
+ "Autologging is managed by MLFlow, while manual logging is maintained by FLAML.\n",
+ "\n",
+ "\n",
+ "#### Details of Manual Logging\n",
+ "FLAML logs general artifacts for AutoML tasks. Specifically, we log these artifacts:\n",
+ "\n",
+ "**`flaml.tune`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the best metric and the best configuration for the entire tuning process.\n",
+ "- For each trial, we create a child run to log the metric specific to the tune function and the configuration for that trial.\n",
+ "\n",
+ "**`flaml.automl`**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "- We create a parent run to log the results of the experiment. This includes:\n",
+ " - The configuration of this model.\n",
+ " - The `best_validation_loss` produced by this model.\n",
+ " - The `best_iteration` to identify the point at which this model was found.\n",
+ "- For each state (a specific learner with different hyperparameters), we record the best trial for this model. This includes:\n",
+ " - The configuration of the best trial.\n",
+ " - The `validation_loss` the best trial produces.\n",
+ " - The `iter_count` to identify how many trials we have conducted for this state.\n",
+ " - The `pred_time`, which is the time cost of predicting test data for this model.\n",
+ " - The `wall_clock_time`, which is the time cost of this state.\n",
+ " - The `sample_size` to show how much data we sampled in this state.\n",
+ "Note that we also added these information to autolog AutoML run.\n",
+ "\n",
+ "\n",
+ "#### Details of Autologging\n",
+ "Autolog artifacts typically include model parameters, model files, and runtime metrics like the following:\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Artifacts can differ among various machine learning libraries. More detailed information can be found [here](https://mlflow.org/docs/latest/tracking.html#automatic-logging).\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "## Plot Experiment Result\n",
+ "The `flaml.visualization` module provides utility functions for plotting the optimization process using [plotly](https://plotly.com/python/). Leveraging `plotly`, users can interactively explore experiment results. To use these plotting functions, simply provide your optimized `flaml.AutoML` or `flaml.tune.tune.ExperimentAnalysis` object as input. Optional parameters can be added using keyword arguments.\n",
+ "\n",
+ "Avaliable plotting functions:\n",
+ "- `plot_optimization_history`: Plot optimization history of all trials in the experiment.\n",
+ "- `plot_feature_importance`: Plot importance for each feature in the dataset.\n",
+ "- `plot_parallel_coordinate`: Plot the high-dimensional parameter relationships in the experiment.\n",
+ "- `plot_contour`: Plot the parameter relationship as contour plot in the experiment.\n",
+ "- `plot_edf`: Plot the objective value EDF (empirical distribution function) of the experiment.\n",
+ "- `plot_timeline`: Plot the timeline of the experiment.\n",
+ "- `plot_slice`: Plot the parameter relationship as slice plot in a study.\n",
+ "\n",
+ "### Figure Examples\n",
+ "\n",
+ "\n",
+ "Check out our example [notebook](../../notebook/trident/automl_plot.ipynb) for a preview of all interactive plots.\n",
+ "\n",
+ "\n",
+ "\n",
+ "\u001b[32m*************************************************************\u001b[0m\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mProduct_Manager\u001b[0m (to chat_manager):\n",
+ "\n",
+ "To use Spark for parallel training in FLAML, you can follow these steps:\n",
+ "\n",
+ "1. Prepare your data in the required format. FLAML only consumes Spark data for Spark estimators. You can use the `to_pandas_on_spark` function from the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark dataframe. Here's an example:\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "from flaml.automl.spark.utils import to_pandas_on_spark\n",
+ "\n",
+ "# Create a dictionary\n",
+ "data = {\n",
+ " \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
+ " \"Age_Years\": [20, 15, 10, 7, 25],\n",
+ " \"Price\": [100000, 200000, 300000, 240000, 120000]\n",
+ "}\n",
+ "\n",
+ "# Create a pandas DataFrame\n",
+ "dataframe = pd.DataFrame(data)\n",
+ "label = \"Price\"\n",
+ "\n",
+ "# Convert to pandas-on-spark dataframe\n",
+ "psdf = to_pandas_on_spark(dataframe)\n",
+ "```\n",
+ "\n",
+ "2. Use the Spark ML estimators in FLAML. FLAML integrates estimators based on Spark ML models. You can include the models you want to try in the `estimator_list` argument when creating an instance of `flaml.AutoML`. Here's an example:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "\n",
+ "automl = flaml.AutoML()\n",
+ "settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"estimator_list\": [\"lgbm_spark\"], # Optional: specify the Spark ML estimator\n",
+ " \"task\": \"regression\"\n",
+ "}\n",
+ "\n",
+ "automl.fit(\n",
+ " dataframe=psdf,\n",
+ " label=label,\n",
+ " **settings\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "3. Activate Spark as the parallel backend. You can set the `use_spark` parameter to `True` to activate Spark as the parallel backend during parallel tuning. FLAML will dispatch your job to the distributed Spark backend using `joblib-spark`. Here's an example:\n",
+ "\n",
+ "```python\n",
+ "import flaml\n",
+ "\n",
+ "automl_experiment = flaml.AutoML()\n",
+ "automl_settings = {\n",
+ " \"time_budget\": 30,\n",
+ " \"metric\": \"r2\",\n",
+ " \"task\": \"regression\",\n",
+ " \"n_concurrent_trials\": 2,\n",
+ " \"use_spark\": True,\n",
+ " \"force_cancel\": True # Optional: force cancel Spark jobs if time budget is exceeded\n",
+ "}\n",
+ "\n",
+ "automl_experiment.fit(\n",
+ " dataframe=dataframe,\n",
+ " label=label,\n",
+ " **automl_settings\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "These are the steps to use Spark for parallel training in FLAML. Let me know if you need any further assistance!\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mCode_Reviewer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "Great! You now have the steps to use Spark for parallel training in FLAML. If you have any more questions, feel free to ask.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
+ "\u001b[33mBoss\u001b[0m (to chat_manager):\n",
+ "\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mSenior_Python_Engineer\u001b[0m (to chat_manager):\n",
+ "\n",
+ "TERMINATE\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[31m\n",
+ ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "call_rag_chat()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "flaml",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebook/agentchat_langchain.ipynb b/notebook/agentchat_langchain.ipynb
index 76afa8fcd..51bc98a84 100644
--- a/notebook/agentchat_langchain.ipynb
+++ b/notebook/agentchat_langchain.ipynb
@@ -366,7 +366,7 @@
"id": "11cc4e60",
"metadata": {},
"source": [
- "# A PySpark Examle"
+ "# A PySpark Example"
]
},
{
diff --git a/notebook/agentchat_teachability.ipynb b/notebook/agentchat_teachability.ipynb
new file mode 100644
index 000000000..54f73fbcb
--- /dev/null
+++ b/notebook/agentchat_teachability.ipynb
@@ -0,0 +1,791 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Chatting with TeachableAgent\n",
+ "\n",
+ "Conversational assistants based on LLMs can remember the current chat with the user, and can even demonstrate in-context learning of things that the user teaches the assistant during the chat. But these memories and learnings are lost once the chat is over, or when a single chat grows too long for the LLM to handle effectively. In subsequent chats, the user is forced to repeat any necessary instructions over and over.\n",
+ "\n",
+ "`TeachableAgent` addresses these limitations by persisting user teachings across chat boundaries in long-term memory (a vector database). Memory is saved to disk at the end of each chat, then loaded from disk at the start of the next. Instead of copying all of memory into the context window, which would eat up valuable space, individual memories (called memos) are retrieved into context as needed. This allows the user to teach frequently used facts and skills to the teachable agent just once, and have it remember them in later chats.\n",
+ "\n",
+ "In making decisions about memo storage and retrieval, `TeachableAgent` calls an instance of `TextAnalyzerAgent` to analyze pieces of text in several different ways. This adds extra LLM calls involving a relatively small number of tokens. These calls can add a few seconds to the time a user waits for a response.\n",
+ "\n",
+ "This notebook demonstrates how `TeachableAgent` can learn facts, preferences, and skills from users. To chat with `TeachableAgent` yourself, run [chat_with_teachable_agent.py](../test/agentchat/chat_with_teachable_agent.py).\n",
+ "\n",
+ "## Requirements\n",
+ "\n",
+ "AutoGen requires `Python>=3.8`. To run this notebook example, please install the [teachable] option.\n",
+ "```bash\n",
+ "pip install \"pyautogen[teachable]\"\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture --no-stderr\n",
+ "# %pip install \"pyautogen[teachable]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Set your API Endpoint\n",
+ "\n",
+ "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "gpt-4\n"
+ ]
+ }
+ ],
+ "source": [
+ "import autogen\n",
+ "\n",
+ "config_list = autogen.config_list_from_json(\n",
+ " env_or_file=\"OAI_CONFIG_LIST\",\n",
+ " file_location=\".\",\n",
+ " filter_dict={\n",
+ " \"model\": [\"gpt-4\", \"gpt4\", \"gpt-4-32k\"],\n",
+ " },\n",
+ ")\n",
+ "\n",
+ "print(config_list[0][\"model\"])"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). After application of this particular filter, only the gpt-4 models are kept.\n",
+ "\n",
+ "The config list looks like the following:\n",
+ "```python\n",
+ "config_list = [\n",
+ " {\n",
+ " 'model': 'gpt-4',\n",
+ " 'api_key': '',\n",
+ " },\n",
+ " {\n",
+ " 'model': 'gpt-4',\n",
+ " 'api_key': '',\n",
+ " 'api_base': '',\n",
+ " 'api_type': 'azure',\n",
+ " 'api_version': '2023-06-01-preview',\n",
+ " },\n",
+ " {\n",
+ " 'model': 'gpt-4-32k',\n",
+ " 'api_key': '',\n",
+ " 'api_base': '',\n",
+ " 'api_type': 'azure',\n",
+ " 'api_version': '2023-06-01-preview',\n",
+ " },\n",
+ "]\n",
+ "```\n",
+ "\n",
+ "If you open this notebook in colab, you can upload your files by clicking the file icon on the left panel and then choose \"upload file\" icon.\n",
+ "\n",
+ "You can set the value of config_list in other ways if you prefer, e.g., loading from a YAML file."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Construct Agents\n",
+ "For this walkthrough, we start by resetting the teachable agent's memory store. This deletes any memories from prior conversations that may be stored on disk."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[92m\n",
+ "CLEARING MEMORY\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "from autogen.agentchat.contrib.teachable_agent import TeachableAgent\n",
+ "from autogen import UserProxyAgent\n",
+ "\n",
+ "llm_config = {\n",
+ " \"request_timeout\": 60,\n",
+ " \"config_list\": config_list,\n",
+ " \"use_cache\": True, # Use False to explore LLM non-determinism.\n",
+ "}\n",
+ "\n",
+ "teach_config={\n",
+ " \"verbosity\": 0, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.\n",
+ " \"reset_db\": True, # Set to True to start over with an empty database.\n",
+ " \"path_to_db_dir\": \"./tmp/notebook/teachable_agent_db\", # Path to the directory where the database will be stored.\n",
+ " \"recall_threshold\": 1.5, # Higher numbers allow more (but less relevant) memos to be recalled.\n",
+ "}\n",
+ "\n",
+ "try:\n",
+ " from termcolor import colored\n",
+ "except ImportError:\n",
+ " def colored(x, *args, **kwargs):\n",
+ " return x\n",
+ " \n",
+ "autogen.ChatCompletion.start_logging()\n",
+ "\n",
+ "teachable_agent = TeachableAgent(\n",
+ " name=\"teachableagent\",\n",
+ " llm_config=llm_config,\n",
+ " teach_config=teach_config)\n",
+ "\n",
+ "user = UserProxyAgent(\n",
+ " name=\"user\",\n",
+ " human_input_mode=\"NEVER\",\n",
+ " is_termination_msg=lambda x: True if \"TERMINATE\" in x.get(\"content\") else False,\n",
+ " max_consecutive_auto_reply=0,\n",
+ ")\n",
+ "\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Learning new facts\n",
+ "Let's teach the agent some facts it doesn't already know, since they are more recent than GPT-4's training data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "What is the Vicuna model?\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "I'm sorry, but I don't have information about the \"Vicuna model\" in my current database. Could you please provide some context, description, or details about this model so I can better assist you?\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"What is the Vicuna model?\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Vicuna is a 13B-parameter language model released by Meta.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Thank you for providing the information. The Vicuna model is a large-scale language model developed by Meta. It consists of 13 billion parameters, which is what makes it able to understand and generate human-like text across various topics. Parameters in a neural network are the weights and biases assigned to each connection between the neurons, and having more parameters enables the model to learn and capture more complex patterns and relationships in data. As an AI language model, Vicuna can be utilized in various language-related tasks, such as text completion, translation, summarization, and more.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"Vicuna is a 13B-parameter language model released by Meta.\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "What is the Orca model?\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "The Orca model is a yet-to-be-released language model by Meta, referring to the Research Preview API release of Meta's advanced language model. While I don't have specific details about the Orca model at this time, it is expected that the model would have an even larger number of parameters to further enhance its capabilities compared to its predecessor, Vicuna.\n",
+ "\n",
+ "As AI language models evolve, they generally become more efficient at understanding and generating human-like text. This makes them more effective in completing a wide range of tasks, including text completion, translation, summarization, and Q&A, as well as providing more accurate and coherent results in various applications.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"What is the Orca model?\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Orca is a 13B-parameter language model released by Microsoft. It outperforms Vicuna on most tasks.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Thank you for the clarification. Orca is a 13-billion-parameter language model developed by Microsoft. It is designed to outperform Vicuna, the language model released by Meta. With a large number of parameters, Orca is capable of understanding and generating human-like text across diverse subjects, making it highly useful in various applications.\n",
+ "\n",
+ "Orca's advanced capabilities allow it to excel in tasks such as text completion, translation, summarization, Q&A, and more. Its performance improvement over Vicuna highlights the rapid advancements in natural language processing (NLP) research and the development of increasingly capable AI language models.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"Orca is a 13B-parameter language model released by Microsoft. It outperforms Vicuna on most tasks.\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's end our first chat here. The following function needs to be called at the end of each chat, so that `TeachableAgent` can store what the user has taught it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[93m\n",
+ "REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "teachable_agent.learn_from_user_feedback()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now let's start a new chat by clearing the previous chat's history. At this point, common LLM-based assistants would forget everything from the last chat. But `TeachableAgent` can retrieve memories from its vector DB as needed, allowing it to recall and reason over facts that the user taught it in earlier conversations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "How does the Vicuna model compare to the Orca model?\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "The Vicuna and Orca models are both 13B-parameter language models released by Meta and Microsoft, respectively. In terms of performance, Orca has been reported to outperform Vicuna on most tasks. However, without detailed information about specific tasks and benchmarks, it is difficult to provide a more comprehensive comparison. Generally speaking, both models are advanced language models that aim to provide high-quality natural language understanding and generation, but Orca appears to have an edge in terms of overall performance.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"How does the Vicuna model compare to the Orca model?\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Learning user preferences\n",
+ "Now let's teach the agent some of our preferences. Suppose that we frequently post short summaries of new papers for our team to read, and we want the teachable agent to help us do this faster."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Please summarize this abstract.\n",
+ "\n",
+ "AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
+ "Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang, Xiaoyun Zhang, Shaokun Zhang, Jiale Liu, Ahmed Hassan Awadallah, Ryen W White, Doug Burger, Chi Wang\n",
+ "AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "AutoGen is an open-source framework designed to enable developers to create LLM applications with multiple customizable agents that can converse with each other to complete tasks. These agents can operate using different combinations of LLMs, human inputs, and tools, allowing developers to define agent interaction behaviors flexibly. AutoGen supports programming flexible conversation patterns using both natural language and code, making it suitable for building diverse applications with varying complexities and LLM capacities. Its effectiveness has been demonstrated through empirical studies across various domains including mathematics, coding, operations research, decision-making, and entertainment.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\"Please summarize this abstract.\n",
+ "\n",
+ "AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
+ "Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang, Xiaoyun Zhang, Shaokun Zhang, Jiale Liu, Ahmed Hassan Awadallah, Ryen W White, Doug Burger, Chi Wang\n",
+ "AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "But that's unstructured. So let's teach the agent our preference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Please summarize this abstract. \n",
+ "When I'm summarizing an abstract, I try to make the summary contain just three short bullet points: the title, the innovation, and the key empirical results.\n",
+ "\n",
+ "AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
+ "Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang, Xiaoyun Zhang, Shaokun Zhang, Jiale Liu, Ahmed Hassan Awadallah, Ryen W White, Doug Burger, Chi Wang\n",
+ "AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "- Title: AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
+ "- Innovation: Open-source framework for creating customizable LLM applications through agent conversations, supporting various modes and interaction behaviors.\n",
+ "- Key Empirical Results: Demonstrated effectiveness across diverse application domains, including mathematics, coding, question answering, and more.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\"Please summarize this abstract. \n",
+ "When I'm summarizing an abstract, I try to make the summary contain just three short bullet points: the title, the innovation, and the key empirical results.\n",
+ "\n",
+ "AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
+ "Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang, Xiaoyun Zhang, Shaokun Zhang, Jiale Liu, Ahmed Hassan Awadallah, Ryen W White, Doug Burger, Chi Wang\n",
+ "AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "That's much better, but will the teachable agent remember these preferences in the future, for a different paper? Let's start a new chat to find out!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[93m\n",
+ "REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n",
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Please summarize this abstract.\n",
+ "\n",
+ "Sparks of Artificial General Intelligence: Early experiments with GPT-4\n",
+ "Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, Harsha Nori, Hamid Palangi, Marco Tulio Ribeiro, Yi Zhang\n",
+ "Artificial intelligence (AI) researchers have been developing and refining large language models (LLMs) that exhibit remarkable capabilities across a variety of domains and tasks, challenging our understanding of learning and cognition. The latest model developed by OpenAI, GPT-4, was trained using an unprecedented scale of compute and data. In this paper, we report on our investigation of an early version of GPT-4, when it was still in active development by OpenAI. We contend that (this early version of) GPT-4 is part of a new cohort of LLMs (along with ChatGPT and Google's PaLM for example) that exhibit more general intelligence than previous AI models. We discuss the rising capabilities and implications of these models. We demonstrate that, beyond its mastery of language, GPT-4 can solve novel and difficult tasks that span mathematics, coding, vision, medicine, law, psychology and more, without needing any special prompting. Moreover, in all of these tasks, GPT-4's performance is strikingly close to human-level performance, and often vastly surpasses prior models such as ChatGPT. Given the breadth and depth of GPT-4's capabilities, we believe that it could reasonably be viewed as an early (yet still incomplete) version of an artificial general intelligence (AGI) system. In our exploration of GPT-4, we put special emphasis on discovering its limitations, and we discuss the challenges ahead for advancing towards deeper and more comprehensive versions of AGI, including the possible need for pursuing a new paradigm that moves beyond next-word prediction. We conclude with reflections on societal influences of the recent technological leap and future research directions.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "- Title: Sparks of Artificial General Intelligence: Early experiments with GPT-4\n",
+ "- Innovation: GPT-4, an LLM with remarkable capabilities, demonstrates human-level performance across various domains, like math, coding, vision, medicine, law, and psychology.\n",
+ "- Key results: GPT-4 significantly surpasses prior models, suggesting it may be an early version of AGI; limitations and challenges toward deeper AGI are also discussed.\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "teachable_agent.learn_from_user_feedback()\n",
+ "\n",
+ "text = \"\"\"Please summarize this abstract.\n",
+ "\n",
+ "Sparks of Artificial General Intelligence: Early experiments with GPT-4\n",
+ "Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, Harsha Nori, Hamid Palangi, Marco Tulio Ribeiro, Yi Zhang\n",
+ "Artificial intelligence (AI) researchers have been developing and refining large language models (LLMs) that exhibit remarkable capabilities across a variety of domains and tasks, challenging our understanding of learning and cognition. The latest model developed by OpenAI, GPT-4, was trained using an unprecedented scale of compute and data. In this paper, we report on our investigation of an early version of GPT-4, when it was still in active development by OpenAI. We contend that (this early version of) GPT-4 is part of a new cohort of LLMs (along with ChatGPT and Google's PaLM for example) that exhibit more general intelligence than previous AI models. We discuss the rising capabilities and implications of these models. We demonstrate that, beyond its mastery of language, GPT-4 can solve novel and difficult tasks that span mathematics, coding, vision, medicine, law, psychology and more, without needing any special prompting. Moreover, in all of these tasks, GPT-4's performance is strikingly close to human-level performance, and often vastly surpasses prior models such as ChatGPT. Given the breadth and depth of GPT-4's capabilities, we believe that it could reasonably be viewed as an early (yet still incomplete) version of an artificial general intelligence (AGI) system. In our exploration of GPT-4, we put special emphasis on discovering its limitations, and we discuss the challenges ahead for advancing towards deeper and more comprehensive versions of AGI, including the possible need for pursuing a new paradigm that moves beyond next-word prediction. We conclude with reflections on societal influences of the recent technological leap and future research directions.\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Learning new skills\n",
+ "Finally, let's extend the teachable agent's capabilities by teaching it a new skill for accomplishing a challenging type of task. \n",
+ "\n",
+ "The [Sparks of AGI](https://arxiv.org/abs/2303.12712) paper evaluated GPT-4 on math problems like the following, which it could only solve 32% of the time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Step 1: Identify the current equation and the target value on the right-hand side.\n",
+ "Current equation: 9 * 4 + 6 * 6 = 72\n",
+ "Target value: 99\n",
+ "\n",
+ "Step 2: Determine what numbers can be changed on the left-hand side.\n",
+ "Possible changes: 9, 4, 6 (first occurrence), and 6 (second occurrence)\n",
+ "\n",
+ "Step 3: Evaluate whether each change in the numbers on the left-hand side can result in the target value.\n",
+ "9 -> If changed, it would create a larger product.\n",
+ "4 -> If changed, the product would increase or decrease by more than the amount needed to reach 99.\n",
+ "6 (first occurrence) -> If increased to 7, then the equation would be 9 * 4 + 7 * 6 = 94, which is still not 99.\n",
+ "6 (second occurrence) -> If increased to 7, then the equation would be 9 * 4 + 6 * 7 = 99.\n",
+ "\n",
+ "Step 4: Write down the solution.\n",
+ "The solution is: 9 * 4 + 6 * 7\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\"Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This time, lets teach the agent a reliable strategy for solving such tasks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "Here's some advice that may help:\n",
+ "1. Let E denote the original number on the right.\n",
+ "2. Let F denote the final number on the right.\n",
+ "3. Calculate the difference between the two, G = F - E.\n",
+ "4. Examine the numbers on the left one by one until finding one that divides evenly into G, where negative integers are allowed.\n",
+ "5. Calculate J = G / H. This is the number of times that H divides into G.\n",
+ "6. Verify that J is an integer, and that H * J = G.\n",
+ "7. Find the number on the left which is multiplied by H, and call it K.\n",
+ "8. Change K to K + J.\n",
+ "9. Recompute the value on the left, and verify that it equals F.\n",
+ "Finally, write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Step 1: Identify the original number (E) and the final number (F) on the right-hand side.\n",
+ "E = 72\n",
+ "F = 99\n",
+ "\n",
+ "Step 2: Calculate the difference between the two, G = F - E.\n",
+ "G = 99 - 72 = 27\n",
+ "\n",
+ "Step 3: Examine the numbers on the left one by one until finding one that divides evenly into G, where negative integers are allowed.\n",
+ "9 (divides evenly into 27)\n",
+ "4 (does not)\n",
+ "6 (first occurrence, does not)\n",
+ "6 (second occurrence, does not)\n",
+ "\n",
+ "Step 4: Calculate J = G / H. This is the number of times that H divides into G.\n",
+ "H = 9 (since 9 divides evenly into 27)\n",
+ "J = 27 / 9 = 3\n",
+ "\n",
+ "Step 5: Verify that J is an integer, and that H * J = G.\n",
+ "H * J = 9 * 3 = 27\n",
+ "J = 3 (which is an integer)\n",
+ "\n",
+ "Step 6: Find the number on the left which is multiplied by H, and call it K.\n",
+ "Number on the leftside of the equation that is multiplied by 9: 4\n",
+ "Therefore, K = 4\n",
+ "\n",
+ "Step 7: Change K to K + J.\n",
+ "K + J = 4 + 3 = 7\n",
+ "\n",
+ "Step 8: Recompute the value on the left, and verify that it equals F.\n",
+ "9 * 7 + 6 * 6 = 63 + 36 = 99 (which is equal to F)\n",
+ "\n",
+ "Step 9: Write down the solution.\n",
+ "The solution is: 9 * 7 + 6 * 6\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\"Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "Here's some advice that may help:\n",
+ "1. Let E denote the original number on the right.\n",
+ "2. Let F denote the final number on the right.\n",
+ "3. Calculate the difference between the two, G = F - E.\n",
+ "4. Examine the numbers on the left one by one until finding one that divides evenly into G, where negative integers are allowed.\n",
+ "5. Calculate J = G / H. This is the number of times that H divides into G.\n",
+ "6. Verify that J is an integer, and that H * J = G.\n",
+ "7. Find the number on the left which is multiplied by H, and call it K.\n",
+ "8. Change K to K + J.\n",
+ "9. Recompute the value on the left, and verify that it equals F.\n",
+ "Finally, write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When given this advice, GPT-4 can solve such problems over 95% of the time. But can the teachable agent remember the strategy so the user doesn't have to explain it over and over? As before, let's start a new chat to find out."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[93m\n",
+ "REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n",
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Let's follow the advice to modify one integer on the left hand side of the equation to make the right hand side become 99:\n",
+ "\n",
+ "1. Let E denote the original number on the right: E = 72.\n",
+ "2. Let F denote the final number on the right: F = 99.\n",
+ "3. Calculate the difference between the two, G = F - E: G = 99 - 72 = 27.\n",
+ "4. Examine the numbers on the left one by one to find one that divides evenly into G: 9, 4, 6. Since 9 divides 27 evenly, we will work with 9 (H = 9).\n",
+ "5. Calculate J = G / H: J = 27 / 9 = 3.\n",
+ "6. Verify that J is an integer, and that H * J = G: J is an integer, and 9 * 3 = 27.\n",
+ "7. Find the number on the left which is multiplied by H (9), and call it K: K = 4.\n",
+ "8. Change K to K + J: K_new = 4 + 3 = 7.\n",
+ "9. Recompute the value on the left, and verify that it equals F: (9 * 7) + (6 * 6) = 63 + 36 = 99.\n",
+ "\n",
+ "The solution is: 9 * 7 + 6 * 6\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "teachable_agent.learn_from_user_feedback()\n",
+ "\n",
+ "text = \"\"\"Consider the identity: \n",
+ "9 * 4 + 6 * 6 = 72\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a final check, let's test the teachable agent's newly learned skill on a separate instance of the task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[33muser\u001b[0m (to teachableagent):\n",
+ "\n",
+ "Consider the identity: \n",
+ "8 * 3 + 7 * 9 = 87\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 59?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n",
+ "\u001b[33mteachableagent\u001b[0m (to user):\n",
+ "\n",
+ "Let's follow the advice to modify one integer on the left hand side of the equation to make the right hand side become 59:\n",
+ "\n",
+ "1. Let E denote the original number on the right: E = 87.\n",
+ "2. Let F denote the final number on the right: F = 59.\n",
+ "3. Calculate the difference between the two, G = F - E: G = 59 - 87 = -28.\n",
+ "4. Examine the numbers on the left one by one to find one that divides evenly into G: 8, 3, 7, 9. Since 7 divides -28 evenly, we will work with 7 (H = 7).\n",
+ "5. Calculate J = G / H: J = -28 / 7 = -4.\n",
+ "6. Verify that J is an integer, and that H * J = G: J is an integer, and 7 * (-4) = -28.\n",
+ "7. Find the number on the left which is multiplied by H (7), and call it K: K = 9.\n",
+ "8. Change K to K + J: K_new = 9 + (-4) = 5.\n",
+ "9. Recompute the value on the left, and verify that it equals F: (8 * 3) + (7 * 5) = 24 + 35 = 59.\n",
+ "\n",
+ "The solution is: 8 * 3 + 7 * 5\n",
+ "\n",
+ "--------------------------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "text = \"\"\"Consider the identity: \n",
+ "8 * 3 + 7 * 9 = 87\n",
+ "Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 59?\n",
+ "-Let's think step-by-step, write down a plan, and then write down your solution as: \"The solution is: A * B + C * D\".\n",
+ "\"\"\"\n",
+ "user.initiate_chat(teachable_agent, message=text, clear_history=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "flaml",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.17"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/setup.py b/setup.py
index dd8a71b3f..33c5cb0e0 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ with open(os.path.join(here, "autogen/version.py")) as fp:
__version__ = version["__version__"]
install_requires = [
- "openai",
+ "openai<1",
"diskcache",
"termcolor",
"flaml",
@@ -58,6 +58,7 @@ setuptools.setup(
"blendsearch": ["flaml[blendsearch]"],
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
"retrievechat": ["chromadb", "tiktoken", "sentence_transformers", "pypdf"],
+ "teachable": ["chromadb"],
},
classifiers=[
"Programming Language :: Python :: 3",
diff --git a/test/agentchat/chat_with_teachable_agent.py b/test/agentchat/chat_with_teachable_agent.py
new file mode 100644
index 000000000..211ebe590
--- /dev/null
+++ b/test/agentchat/chat_with_teachable_agent.py
@@ -0,0 +1,60 @@
+from autogen import UserProxyAgent, config_list_from_json
+from autogen.agentchat.contrib.teachable_agent import TeachableAgent
+
+
+try:
+ from termcolor import colored
+except ImportError:
+
+ def colored(x, *args, **kwargs):
+ return x
+
+
+verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
+recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
+use_cache = False # If True, cached LLM calls will be skipped and responses pulled from cache. False exposes LLM non-determinism.
+
+# Specify the model to use. GPT-3.5 is less reliable than GPT-4 at learning from user input.
+filter_dict = {"model": ["gpt-4"]}
+
+
+def create_teachable_agent(reset_db=False):
+ """Instantiates a TeachableAgent using the settings from the top of this file."""
+ # Load LLM inference endpoints from an env variable or a file
+ # See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
+ # and OAI_CONFIG_LIST_sample
+ config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST", filter_dict=filter_dict)
+ teachable_agent = TeachableAgent(
+ name="teachableagent",
+ llm_config={"config_list": config_list, "request_timeout": 120, "use_cache": use_cache},
+ teach_config={
+ "verbosity": verbosity,
+ "reset_db": reset_db,
+ "path_to_db_dir": "./tmp/interactive/teachable_agent_db",
+ "recall_threshold": recall_threshold,
+ },
+ )
+ return teachable_agent
+
+
+def interact_freely_with_user():
+ """Starts a free-form chat between the user and TeachableAgent."""
+
+ # Create the agents.
+ print(colored("\nLoading previous memory (if any) from disk.", "light_cyan"))
+ teachable_agent = create_teachable_agent(reset_db=False)
+ user = UserProxyAgent("user", human_input_mode="ALWAYS")
+
+ # Start the chat.
+ teachable_agent.initiate_chat(user, message="Greetings, I'm a teachable user assistant! What's on your mind today?")
+
+ # Let the teachable agent remember things that should be learned from this chat.
+ teachable_agent.learn_from_user_feedback()
+
+ # Wrap up.
+ teachable_agent.close_db()
+
+
+if __name__ == "__main__":
+ """Lets the user test TeachableAgent interactively."""
+ interact_freely_with_user()
diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py
index 5c5d3fb82..c50ef45cd 100644
--- a/test/agentchat/test_groupchat.py
+++ b/test/agentchat/test_groupchat.py
@@ -1,6 +1,54 @@
+import pytest
import autogen
+def test_func_call_groupchat():
+ agent1 = autogen.ConversableAgent(
+ "alice",
+ human_input_mode="NEVER",
+ llm_config=False,
+ default_auto_reply="This is alice sepaking.",
+ )
+ agent2 = autogen.ConversableAgent(
+ "bob",
+ human_input_mode="NEVER",
+ llm_config=False,
+ default_auto_reply="This is bob speaking.",
+ function_map={"test_func": lambda x: x},
+ )
+ groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=3)
+ group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
+ agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
+
+ assert len(groupchat.messages) == 3
+ assert (
+ groupchat.messages[-2]["role"] == "function"
+ and groupchat.messages[-2]["name"] == "test_func"
+ and groupchat.messages[-2]["content"] == "1"
+ )
+ assert groupchat.messages[-1]["name"] == "alice"
+
+ agent3 = autogen.ConversableAgent(
+ "carol",
+ human_input_mode="NEVER",
+ llm_config=False,
+ default_auto_reply="This is carol speaking.",
+ function_map={"test_func": lambda x: x + 1},
+ )
+ groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=3)
+ group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False)
+ agent3.initiate_chat(group_chat_manager, message={"function_call": {"name": "test_func", "arguments": '{"x": 1}'}})
+
+ assert (
+ groupchat.messages[-2]["role"] == "function"
+ and groupchat.messages[-2]["name"] == "test_func"
+ and groupchat.messages[-2]["content"] == "1"
+ )
+ assert groupchat.messages[-1]["name"] == "carol"
+
+ agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
+
+
def test_chat_manager():
agent1 = autogen.ConversableAgent(
"alice",
@@ -30,6 +78,9 @@ def test_chat_manager():
agent2.initiate_chat(group_chat_manager, message="hello")
assert len(groupchat.messages) == 2
+ with pytest.raises(ValueError):
+ agent2.initiate_chat(group_chat_manager, message={"function_call": {"name": "func", "arguments": '{"x": 1}'}})
+
def test_plugin():
# Give another Agent class ability to manage group chat
@@ -62,6 +113,7 @@ def test_plugin():
if __name__ == "__main__":
+ test_func_call_groupchat()
# test_broadcast()
- # test_chat_manager()
- test_plugin()
+ test_chat_manager()
+ # test_plugin()
diff --git a/test/agentchat/test_teachable_agent.py b/test/agentchat/test_teachable_agent.py
new file mode 100644
index 000000000..7a3367dbd
--- /dev/null
+++ b/test/agentchat/test_teachable_agent.py
@@ -0,0 +1,172 @@
+try:
+ import openai
+
+ skip = False
+except ImportError:
+ skip = True
+import pytest
+import sys
+from autogen import ConversableAgent, config_list_from_json
+from autogen.agentchat.contrib.teachable_agent import TeachableAgent
+
+
+try:
+ from termcolor import colored
+except ImportError:
+
+ def colored(x, *args, **kwargs):
+ return x
+
+
+# Set verbosity levels to maximize code coverage.
+qa_verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
+skill_verbosity = 3 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
+
+assert_on_error = False # GPT-4 nearly always succeeds on these unit tests, but GPT-3.5 is a bit less reliable.
+recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
+use_cache = False # If True, cached LLM calls will be skipped and responses pulled from cache. False exposes LLM non-determinism.
+
+# Specify the model to use by uncommenting one of the following lines.
+# filter_dict={"model": ["gpt-4-0613"]}
+# filter_dict={"model": ["gpt-3.5-turbo-0613"]}
+# filter_dict={"model": ["gpt-4"]}
+filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
+
+
+def create_teachable_agent(reset_db=False, verbosity=0):
+ """Instantiates a TeachableAgent using the settings from the top of this file."""
+ # Load LLM inference endpoints from an env variable or a file
+ # See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
+ # and OAI_CONFIG_LIST_sample
+ config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST", filter_dict=filter_dict)
+ teachable_agent = TeachableAgent(
+ name="teachableagent",
+ llm_config={"config_list": config_list, "request_timeout": 120, "use_cache": use_cache},
+ teach_config={
+ "verbosity": verbosity,
+ "reset_db": reset_db,
+ "path_to_db_dir": "./tmp/teachable_agent_db",
+ "recall_threshold": recall_threshold,
+ },
+ )
+ return teachable_agent
+
+
+def check_agent_response(teachable_agent, user, correct_answer):
+ """Checks whether the agent's response contains the correct answer, and returns the number of errors (1 or 0)."""
+ agent_response = user.last_message(teachable_agent)["content"]
+ if correct_answer not in agent_response:
+ print(colored(f"\nTEST FAILED: EXPECTED ANSWER {correct_answer} NOT FOUND IN AGENT RESPONSE", "light_red"))
+ if assert_on_error:
+ assert correct_answer in agent_response
+ return 1
+ else:
+ print(colored(f"\nTEST PASSED: EXPECTED ANSWER {correct_answer} FOUND IN AGENT RESPONSE", "light_cyan"))
+ return 0
+
+
+def use_question_answer_phrasing():
+ """Tests whether the teachable agent can answer a question after being taught the answer in a previous chat."""
+ print(colored("\nTEST QUESTION-ANSWER PHRASING", "light_cyan"))
+ num_errors, num_tests = 0, 0
+ teachable_agent = create_teachable_agent(
+ reset_db=True, verbosity=qa_verbosity
+ ) # For a clean test, clear the agent's memory.
+ user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
+
+ # Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
+ teachable_agent.prepopulate_db()
+
+ # Ask the teachable agent to do something using terminology it doesn't understand.
+ user.initiate_chat(recipient=teachable_agent, message="What is the twist of 5 and 7?")
+
+ # Explain the terminology to the teachable agent.
+ user.send(
+ recipient=teachable_agent,
+ message="Actually, the twist of two or more numbers is their product minus their sum. Try again.",
+ )
+ num_errors += check_agent_response(teachable_agent, user, "23")
+ num_tests += 1
+
+ # Let the teachable agent remember things that should be learned from this chat.
+ teachable_agent.learn_from_user_feedback()
+
+ # Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
+ print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
+ user.initiate_chat(recipient=teachable_agent, message="What's the twist of 8 and 3 and 2?")
+ num_errors += check_agent_response(teachable_agent, user, "35")
+ num_tests += 1
+
+ # Wrap up.
+ teachable_agent.close_db()
+ return num_errors, num_tests
+
+
+def use_task_advice_pair_phrasing():
+ """Tests whether the teachable agent can demonstrate a new skill after being taught a task-advice pair in a previous chat."""
+ print(colored("\nTEST TASK-ADVICE PHRASING", "light_cyan"))
+ num_errors, num_tests = 0, 0
+ teachable_agent = create_teachable_agent(
+ reset_db=True, verbosity=skill_verbosity # For a clean test, clear the teachable agent's memory.
+ )
+ user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
+
+ # Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
+ teachable_agent.prepopulate_db()
+
+ # Ask the teachable agent to do something, and provide some helpful advice.
+ user.initiate_chat(
+ recipient=teachable_agent,
+ message="Compute the twist of 5 and 7. Here's a hint: The twist of two or more numbers is their product minus their sum.",
+ )
+ num_errors += check_agent_response(teachable_agent, user, "23")
+ num_tests += 1
+
+ # Let the teachable agent remember things that should be learned from this chat.
+ teachable_agent.learn_from_user_feedback()
+
+ # Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
+ print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
+ user.initiate_chat(recipient=teachable_agent, message="Please calculate the twist of 8 and 3 and 2.")
+ num_errors += check_agent_response(teachable_agent, user, "35")
+ num_tests += 1
+
+ # Wrap up.
+ teachable_agent.close_db()
+ return num_errors, num_tests
+
+
+@pytest.mark.skipif(
+ skip or not sys.version.startswith("3.9"),
+ reason="do not run if openai is not installed or py!=3.9",
+)
+def test_all():
+ """Runs this file's unit tests."""
+ total_num_errors, total_num_tests = 0, 0
+
+ num_trials = 1 # Set to a higher number to get a more accurate error rate.
+ for trial in range(num_trials):
+ num_errors, num_tests = use_question_answer_phrasing()
+ total_num_errors += num_errors
+ total_num_tests += num_tests
+
+ num_errors, num_tests = use_task_advice_pair_phrasing()
+ total_num_errors += num_errors
+ total_num_tests += num_tests
+
+ print(colored(f"\nTRIAL {trial + 1} OF {num_trials} FINISHED", "light_cyan"))
+
+ if total_num_errors == 0:
+ print(colored("\nTEACHABLE AGENT TESTS FINISHED WITH ZERO ERRORS", "light_cyan"))
+ else:
+ print(
+ colored(
+ f"\nTEACHABLE AGENT TESTS FINISHED WITH {total_num_errors} / {total_num_tests} TOTAL ERRORS ({100.0 * total_num_errors / total_num_tests}%)",
+ "light_red",
+ )
+ )
+
+
+if __name__ == "__main__":
+ """Runs this file's unit tests from the command line."""
+ test_all()
diff --git a/website/blog/2023-05-18-GPT-adaptive-humaneval/index.mdx b/website/blog/2023-05-18-GPT-adaptive-humaneval/index.mdx
index 7e77db8f5..924ca4eb3 100644
--- a/website/blog/2023-05-18-GPT-adaptive-humaneval/index.mdx
+++ b/website/blog/2023-05-18-GPT-adaptive-humaneval/index.mdx
@@ -16,7 +16,7 @@ In this blog post, we will explore a creative, adaptive way of using GPT models
## Observations
-* GPT-3.5-Turbo can alrady solve 40%-50% tasks. For these tasks if we never use GPT-4, we can save nearly 40-50% cost.
+* GPT-3.5-Turbo can already solve 40%-50% tasks. For these tasks if we never use GPT-4, we can save nearly 40-50% cost.
* If we use the saved cost to generate more responses with GPT-4 for the remaining unsolved tasks, it is possible to solve some more of them while keeping the amortized cost down.
The obstacle of leveraging these observations is that we do not know *a priori* which tasks can be solved by the cheaper model, which tasks can be solved by the expensive model, and which tasks can be solved by paying even more to the expensive model.
diff --git a/website/blog/2023-10-18-RetrieveChat/img/autogen-rag.gif b/website/blog/2023-10-18-RetrieveChat/img/autogen-rag.gif
new file mode 100644
index 000000000..a04c7308d
Binary files /dev/null and b/website/blog/2023-10-18-RetrieveChat/img/autogen-rag.gif differ
diff --git a/website/blog/2023-10-18-RetrieveChat/img/retrievechat-arch.png b/website/blog/2023-10-18-RetrieveChat/img/retrievechat-arch.png
new file mode 100644
index 000000000..a05186a06
Binary files /dev/null and b/website/blog/2023-10-18-RetrieveChat/img/retrievechat-arch.png differ
diff --git a/website/blog/2023-10-18-RetrieveChat/index.mdx b/website/blog/2023-10-18-RetrieveChat/index.mdx
new file mode 100644
index 000000000..71d2ad3f4
--- /dev/null
+++ b/website/blog/2023-10-18-RetrieveChat/index.mdx
@@ -0,0 +1,476 @@
+---
+title: Retrieval-Augmented Generation (RAG) Applications with AutoGen
+authors: thinkall
+tags: [LLM, RAG]
+---
+
+
+
+**TL;DR:**
+* We introduce **RetrieveUserProxyAgent** and **RetrieveAssistantAgent**, RAG agents of AutoGen that
+allows retrieval-augmented generation, and its basic usage.
+* We showcase customizations of RAG agents, such as customizing the embedding function, the text
+split function and vector database.
+* We also showcase two advanced usage of RAG agents, integrating with group chat and building a Chat
+application with Gradio.
+
+
+## Introduction
+Retrieval augmentation has emerged as a practical and effective approach for mitigating the intrinsic
+limitations of LLMs by incorporating external documents. In this blog post, we introduce RAG agents of
+AutoGen that allows retrieval-augmented generation. The system consists of two agents: a
+Retrieval-augmented User Proxy agent, called `RetrieveUserProxyAgent`, and a Retrieval-augmented Assistant
+agent, called `RetrieveAssistantAgent`, both of which are extended from built-in agents from AutoGen.
+The overall architecture of the RAG agents is shown in the figure above.
+
+To use Retrieval-augmented Chat, one needs to initialize two agents including Retrieval-augmented
+User Proxy and Retrieval-augmented Assistant. Initializing the Retrieval-Augmented User Proxy
+necessitates specifying a path to the document collection. Subsequently, the Retrieval-Augmented
+User Proxy can download the documents, segment them into chunks of a specific size, compute
+embeddings, and store them in a vector database. Once a chat is initiated, the agents collaboratively
+engage in code generation or question-answering adhering to the procedures outlined below:
+1. The Retrieval-Augmented User Proxy retrieves document chunks based on the embedding similarity,
+and sends them along with the question to the Retrieval-Augmented Assistant.
+2. The Retrieval-Augmented Assistant employs an LLM to generate code or text as answers based
+on the question and context provided. If the LLM is unable to produce a satisfactory response, it
+is instructed to reply with “Update Context” to the Retrieval-Augmented User Proxy.
+3. If a response includes code blocks, the Retrieval-Augmented User Proxy executes the code and
+sends the output as feedback. If there are no code blocks or instructions to update the context, it
+terminates the conversation. Otherwise, it updates the context and forwards the question along
+with the new context to the Retrieval-Augmented Assistant. Note that if human input solicitation
+is enabled, individuals can proactively send any feedback, including Update Context”, to the
+Retrieval-Augmented Assistant.
+4. If the Retrieval-Augmented Assistant receives “Update Context”, it requests the next most similar
+chunks of documents as new context from the Retrieval-Augmented User Proxy. Otherwise, it
+generates new code or text based on the feedback and chat history. If the LLM fails to generate
+an answer, it replies with “Update Context” again. This process can be repeated several times.
+The conversation terminates if no more documents are available for the context.
+
+## Basic Usage of RAG Agents
+0. Install dependencies
+
+Please install pyautogen with the [retrievechat] option before using RAG agents.
+```bash
+pip install "pyautogen[retrievechat]"
+```
+
+1. Import Agents
+```python
+from autogen
+from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
+from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+```
+
+2. Create an 'RetrieveAssistantAgent' instance named "assistant" and an 'RetrieveUserProxyAgent' instance named "ragproxyagent"
+```python
+assistant = RetrieveAssistantAgent(
+ name="assistant",
+ system_message="You are a helpful assistant.",
+ llm_config=llm_config,
+)
+
+ragproxyagent = RetrieveUserProxyAgent(
+ name="ragproxyagent",
+ retrieve_config={
+ "task": "qa",
+ "docs_path": "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
+ },
+)
+```
+
+3. Initialize Chat and ask a question
+```python
+assistant.reset()
+ragproxyagent.initiate_chat(assistant, problem="What is autogen?")
+```
+
+Output is like:
+```
+--------------------------------------------------------------------------------
+assistant (to ragproxyagent):
+
+AutoGen is a framework that enables the development of large language model (LLM) applications using multiple agents that can converse with each other to solve tasks. The agents are customizable, conversable, and allow human participation. They can operate in various modes that employ combinations of LLMs, human inputs, and tools.
+
+--------------------------------------------------------------------------------
+```
+
+4. Create a UserProxyAgent and ask the same question
+```python
+assistant.reset()
+userproxyagent = autogen.UserProxyAgent(name="userproxyagent")
+userproxyagent.initiate_chat(assistant, message="What is autogen?")
+```
+
+Output is like:
+```
+--------------------------------------------------------------------------------
+assistant (to userproxyagent):
+
+In computer software, autogen is a tool that generates program code automatically, without the need for manual coding. It is commonly used in fields such as software engineering, game development, and web development to speed up the development process and reduce errors. Autogen tools typically use pre-programmed rules, templates, and data to create code for repetitive tasks, such as generating user interfaces, database schemas, and data models. Some popular autogen tools include Visual Studio's Code Generator and Unity's Asset Store.
+
+--------------------------------------------------------------------------------
+```
+
+You can see that the output of `UserProxyAgent` is not related to our `autogen` since the latest info of
+`autogen` is not in ChatGPT's training data. The output of `RetrieveUserProxyAgent` is correct as it can
+perform retrieval-augmented generation based on the given documentation file.
+
+## Customizing RAG Agents
+`RetrieveUserProxyAgent` is customizable with `retrieve_config`. There are several parameters to configure
+based on different use cases. In this section, we'll show how to customize embedding function, text split
+function and vector database.
+
+### Customizing Embedding Function
+By default, [Sentence Transformers](https://www.sbert.net) and its pretrained models will be used to
+compute embeddings. It's possible that you want to use OpenAI, Cohere, HuggingFace or other embedding functions.
+
+* OpenAI
+```python
+from chromadb.utils import embedding_functions
+
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+ api_key="YOUR_API_KEY",
+ model_name="text-embedding-ada-002"
+ )
+
+ragproxyagent = RetrieveUserProxyAgent(
+ name="ragproxyagent",
+ retrieve_config={
+ "task": "qa",
+ "docs_path": "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
+ "embedding_function": openai_ef,
+ },
+)
+```
+
+* HuggingFace
+```python
+huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
+ api_key="YOUR_API_KEY",
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
+)
+```
+
+More examples can be found [here](https://docs.trychroma.com/embeddings).
+
+### Customizing Text Split Function
+Before we can store the documents into a vector database, we need to split the texts into chunks. Although
+we have implemented a flexible text splitter in autogen, you may still want to use different text splitters.
+There are also some existing text split tools which are good to reuse.
+
+For example, you can use all the text splitters in langchain.
+
+```python
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+recur_spliter = RecursiveCharacterTextSplitter(separators=["\n", "\r", "\t"])
+
+ragproxyagent = RetrieveUserProxyAgent(
+ name="ragproxyagent",
+ retrieve_config={
+ "task": "qa",
+ "docs_path": "https://raw.githubusercontent.com/microsoft/autogen/main/README.md",
+ "custom_text_split_function": recur_spliter.split_text,
+ },
+)
+```
+
+
+### Customizing Vector Database
+We are using chromadb as the default vector database, you can also replace it with any other vector database
+by simply overriding the function `retrieve_docs` of `RetrieveUserProxyAgent`.
+
+For example, you can use Qdrant as below:
+
+```python
+# Creating qdrant client
+from qdrant_client import QdrantClient
+
+client = QdrantClient(url="***", api_key="***")
+
+# Wrapping RetrieveUserProxyAgent
+from litellm import embedding as test_embedding
+from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
+from qdrant_client.models import SearchRequest, Filter, FieldCondition, MatchText
+
+class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent):
+ def query_vector_db(
+ self,
+ query_texts: List[str],
+ n_results: int = 10,
+ search_string: str = "",
+ **kwargs,
+ ) -> Dict[str, Union[List[str], List[List[str]]]]:
+ # define your own query function here
+ embed_response = test_embedding('text-embedding-ada-002', input=query_texts)
+
+ all_embeddings: List[List[float]] = []
+
+ for item in embed_response['data']:
+ all_embeddings.append(item['embedding'])
+
+ search_queries: List[SearchRequest] = []
+
+ for embedding in all_embeddings:
+ search_queries.append(
+ SearchRequest(
+ vector=embedding,
+ filter=Filter(
+ must=[
+ FieldCondition(
+ key="page_content",
+ match=MatchText(
+ text=search_string,
+ )
+ )
+ ]
+ ),
+ limit=n_results,
+ with_payload=True,
+ )
+ )
+
+ search_response = client.search_batch(
+ collection_name="{your collection name}",
+ requests=search_queries,
+ )
+
+ return {
+ "ids": [[scored_point.id for scored_point in batch] for batch in search_response],
+ "documents": [[scored_point.payload.get('page_content', '') for scored_point in batch] for batch in search_response],
+ "metadatas": [[scored_point.payload.get('metadata', {}) for scored_point in batch] for batch in search_response]
+ }
+
+ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs):
+ results = self.query_vector_db(
+ query_texts=[problem],
+ n_results=n_results,
+ search_string=search_string,
+ **kwargs,
+ )
+
+ self._results = results
+
+
+# Use QdrantRetrieveUserProxyAgent
+qdrantragagent = QdrantRetrieveUserProxyAgent(
+ name="ragproxyagent",
+ human_input_mode="NEVER",
+ max_consecutive_auto_reply=2,
+ retrieve_config={
+ "task": "qa",
+ },
+)
+
+qdrantragagent.retrieve_docs("What is Autogen?", n_results=10, search_string="autogen")
+```
+
+## Advanced Usage of RAG Agents
+### Integrate with other agents in a group chat
+To use `RetrieveUserProxyAgent` in a group chat is almost the same as you use it in a two agents chat. The only thing is that
+you need to **initialize the chat with `RetrieveUserProxyAgent`**. The `RetrieveAssistantAgent` is not necessary in a group chat.
+
+However, you may want to initialize the chat with another agent in some cases. To leverage the best of `RetrieveUserProxyAgent`,
+you'll need to call it from a function.
+
+```python
+llm_config = {
+ "functions": [
+ {
+ "name": "retrieve_content",
+ "description": "retrieve content for code generation and question answering.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "message": {
+ "type": "string",
+ "description": "Refined message which keeps the original meaning and can be used to retrieve content for code generation and question answering.",
+ }
+ },
+ "required": ["message"],
+ },
+ },
+ ],
+ "config_list": config_list,
+ "request_timeout": 60,
+ "seed": 42,
+}
+
+boss = autogen.UserProxyAgent(
+ name="Boss",
+ is_termination_msg=termination_msg,
+ human_input_mode="TERMINATE",
+ system_message="The boss who ask questions and give tasks.",
+)
+
+boss_aid = RetrieveUserProxyAgent(
+ name="Boss_Assistant",
+ is_termination_msg=termination_msg,
+ system_message="Assistant who has extra content retrieval power for solving difficult problems.",
+ human_input_mode="NEVER",
+ max_consecutive_auto_reply=3,
+ retrieve_config={
+ "task": "qa",
+ },
+ code_execution_config=False, # we don't want to execute code in this case.
+)
+
+coder = AssistantAgent(
+ name="Senior_Python_Engineer",
+ is_termination_msg=termination_msg,
+ system_message="You are a senior python engineer. Reply `TERMINATE` in the end when everything is done.",
+ llm_config=llm_config,
+)
+
+pm = autogen.AssistantAgent(
+ name="Product_Manager",
+ is_termination_msg=termination_msg,
+ system_message="You are a product manager. Reply `TERMINATE` in the end when everything is done.",
+ llm_config=llm_config,
+)
+
+reviewer = autogen.AssistantAgent(
+ name="Code_Reviewer",
+ is_termination_msg=termination_msg,
+ system_message="You are a code reviewer. Reply `TERMINATE` in the end when everything is done.",
+ llm_config=llm_config,
+)
+
+def retrieve_content(message, n_results=3):
+ boss_aid.n_results = n_results # Set the number of results to be retrieved.
+ # Check if we need to update the context.
+ update_context_case1, update_context_case2 = boss_aid._check_update_context(message)
+ if (update_context_case1 or update_context_case2) and boss_aid.update_context:
+ boss_aid.problem = message if not hasattr(boss_aid, "problem") else boss_aid.problem
+ _, ret_msg = boss_aid._generate_retrieve_user_reply(message)
+ else:
+ ret_msg = boss_aid.generate_init_message(message, n_results=n_results)
+ return ret_msg if ret_msg else message
+
+for agent in [boss, coder, pm, reviewer]:
+ # register functions for all agents.
+ agent.register_function(
+ function_map={
+ "retrieve_content": retrieve_content,
+ }
+ )
+
+groupchat = autogen.GroupChat(
+ agents=[boss, coder, pm, reviewer], messages=[], max_round=12
+)
+manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+# Start chatting with boss as this is the user proxy agent.
+boss.initiate_chat(
+ manager,
+ message="How to use spark for parallel training in FLAML? Give me sample code.",
+)
+```
+
+### Build a Chat application with Gradio
+Now, let's wrap it up and make a Chat application with AutoGen and Gradio.
+
+
+
+```python
+# Initialize Agents
+def initialize_agents(config_list, docs_path=None):
+ ...
+ return assistant, ragproxyagent
+
+# Initialize Chat
+def initiate_chat(config_list, problem, queue, n_results=3):
+ ...
+ assistant.reset()
+ try:
+ ragproxyagent.a_initiate_chat(
+ assistant, problem=problem, silent=False, n_results=n_results
+ )
+ messages = ragproxyagent.chat_messages
+ messages = [messages[k] for k in messages.keys()][0]
+ messages = [m["content"] for m in messages if m["role"] == "user"]
+ print("messages: ", messages)
+ except Exception as e:
+ messages = [str(e)]
+ queue.put(messages)
+
+# Wrap AutoGen part into a function
+def chatbot_reply(input_text):
+ """Chat with the agent through terminal."""
+ queue = mp.Queue()
+ process = mp.Process(
+ target=initiate_chat,
+ args=(config_list, input_text, queue),
+ )
+ process.start()
+ try:
+ messages = queue.get(timeout=TIMEOUT)
+ except Exception as e:
+ messages = [str(e) if len(str(e)) > 0 else "Invalid Request to OpenAI, please check your API keys."]
+ finally:
+ try:
+ process.terminate()
+ except:
+ pass
+ return messages
+
+...
+
+# Set up UI with Gradio
+with gr.Blocks() as demo:
+ ...
+ assistant, ragproxyagent = initialize_agents(config_list)
+
+ chatbot = gr.Chatbot(
+ [],
+ elem_id="chatbot",
+ bubble_full_width=False,
+ avatar_images=(None, (os.path.join(os.path.dirname(__file__), "autogen.png"))),
+ # height=600,
+ )
+
+ txt_input = gr.Textbox(
+ scale=4,
+ show_label=False,
+ placeholder="Enter text and press enter",
+ container=False,
+ )
+
+ with gr.Row():
+ txt_model = gr.Dropdown(
+ label="Model",
+ choices=[
+ "gpt-4",
+ "gpt-35-turbo",
+ "gpt-3.5-turbo",
+ ],
+ allow_custom_value=True,
+ value="gpt-35-turbo",
+ container=True,
+ )
+ txt_oai_key = gr.Textbox(
+ label="OpenAI API Key",
+ placeholder="Enter key and press enter",
+ max_lines=1,
+ show_label=True,
+ value=os.environ.get("OPENAI_API_KEY", ""),
+ container=True,
+ type="password",
+ )
+ ...
+
+ clear = gr.ClearButton([txt_input, chatbot])
+
+...
+
+if __name__ == "__main__":
+ demo.launch(share=True)
+```
+
+The online app and the source code are hosted in [HuggingFace](https://huggingface.co/spaces/thinkall/autogen-demos). Feel free to give it a try!
+
+
+## Read More
+You can check out more example notebooks for RAG use cases:
+- [Automated Code Generation and Question Answering with Retrieval Augmented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb)
+- [Group Chat with Retrieval Augmented Generation (with 5 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_RAG.ipynb)
diff --git a/website/blog/authors.yml b/website/blog/authors.yml
index 2aee7a503..85c993604 100644
--- a/website/blog/authors.yml
+++ b/website/blog/authors.yml
@@ -21,3 +21,9 @@ jialeliu:
title: Undergraduate student at Xidian University
url: https://leoljl.github.io
image_url: https://github.com/LeoLjl/leoljl.github.io/blob/main/profile.jpg?raw=true
+
+thinkall:
+ name: Li Jiang
+ title: Senior Software Engineer at Microsoft
+ url: https://github.com/thinkall
+ image_url: https://github.com/thinkall.png
diff --git a/website/docs/Contribute.md b/website/docs/Contribute.md
index 83a3a47ee..8f264c9bd 100644
--- a/website/docs/Contribute.md
+++ b/website/docs/Contribute.md
@@ -54,6 +54,38 @@ print(autogen.__version__)
There is currently no formal reviewer solicitation process. Current reviewers identify reviewers from active contributors. If you are willing to become a reviewer, you are welcome to let us know on discord.
+## Guidance for Maintainers
+
+### General
+
+* Be a member of the community and treat everyone as a member. Be inclusive.
+* Help each other and encourage mutual help.
+* Actively post and respond.
+* Keep open communication.
+
+### Pull Requests
+* For new PR, decide whether to close without review. If not, find the right reviewers. The default reviewer is microsoft/autogen. Ask users who can benefit from the PR to review it.
+* For old PR, check the blocker: reviewer or PR creator. Try to unblock. Get additional help when needed.
+* When requesting changes, make sure you can check back in time because it blocks merging.
+* Make sure all the checks are passed.
+* For changes that require running OpenAI tests, make sure the OpenAI tests pass too. Running these tests requires approval.
+* In general, suggest small PRs instead of a giant PR.
+* For documentation change, request snapshot of the compiled website, or compile by yourself to verify the format.
+* For new contributors who have not signed the contributing agreement, remind them to sign before reviewing.
+* For multiple PRs which may have conflict, coordinate them to figure out the right order.
+* Pay special attention to:
+ - Breaking changes. Don’t make breaking changes unless necessary. Don’t merge to main until enough headsup is provided and a new release is ready.
+ - Test coverage decrease.
+ - Changes that may cause performance degradation. Do regression test when test suites are available.
+ - Discourage **change to the core library** when there is an alternative.
+
+### Issues and Discussions
+* For new issues, write a reply, apply a label if relevant. Ask on discord when necessary. For roadmap issues, add to the roadmap project and encourage community discussion. Mention relevant experts when necessary.
+* For old issues, provide an update or close. Ask on discord when necessary. Encourage PR creation when relevant.
+* Use “good first issue” for easy fix suitable for first-time contributors.
+* Use “task list” for issues that require multiple PRs.
+* For discussions, create an issue when relevant. Discuss on discord when appropriate.
+
## Developing
### Setup
@@ -82,6 +114,14 @@ We have provided the configuration in [devcontainer](https://github.com/microsof
Run `pre-commit install` to install pre-commit into your git hooks. Before you commit, run
`pre-commit run` to check if you meet the pre-commit requirements. If you use Windows (without WSL) and can't commit after installing pre-commit, you can run `pre-commit uninstall` to uninstall the hook. In WSL or Linux this is supposed to work.
+### Write tests
+
+Tests are automatically run via GitHub actions. There are two workflows:
+1. [build.yml](https://github.com/microsoft/autogen/blob/main/.github/workflows/build.yml)
+1. [openai.yml](https://github.com/microsoft/autogen/blob/main/.github/workflows/openai.yml)
+
+The first workflow is required to pass for all PRs. The second workflow is required for changes that affect the openai tests. The second workflow requires approval to run. When writing tests that require openai, please use [`pytest.mark.skipif`](https://github.com/microsoft/autogen/blob/a456b512d5a933ce9707ce51c465ea35a9dd180c/test/test_with_openai.py#L13) to make them run in one python version only when openai is installed. If additional dependency for this test is required, install the dependency in the corresponding python version in [openai.yml](https://github.com/microsoft/autogen/blob/main/.github/workflows/openai.yml).
+
### Coverage
Any code you commit should not decrease coverage. To run all unit tests, install the [test] option:
diff --git a/website/docs/Examples/AutoGen-AgentChat.md b/website/docs/Examples/AutoGen-AgentChat.md
index 1360aa9f5..0a95277d5 100644
--- a/website/docs/Examples/AutoGen-AgentChat.md
+++ b/website/docs/Examples/AutoGen-AgentChat.md
@@ -16,4 +16,6 @@ Links to notebook examples:
* [Automated Complex Task Solving by Group Chat (with 6 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_research.ipynb)
* [Automated Continual Learning from New Data](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_stream.ipynb)
* [Teach Agents New Skills & Reuse via Automated Chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teaching.ipynb)
+* [Teach Agents New Facts, User Preferences and Skills Beyond Coding](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
* [Automated Code Generation and Question Answering with Retrieval Augemented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb)
+* [Group Chat with Retrieval Augmented Generation (with 5 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_RAG.ipynb)
diff --git a/website/docs/FAQ.md b/website/docs/FAQ.md
index 79d033cfb..ccf214c13 100644
--- a/website/docs/FAQ.md
+++ b/website/docs/FAQ.md
@@ -65,6 +65,8 @@ import autogen
config_list = autogen.config_list_from_models(model_list=["gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"])
```
+> For Azure the model name refers to the OpenAI Studio deployment name.
+
The config list looks like the following, if only OpenAI API key is available:
```python
config_list = [
diff --git a/website/docs/Getting-Started.md b/website/docs/Getting-Started.md
index 131985f02..937e8e513 100644
--- a/website/docs/Getting-Started.md
+++ b/website/docs/Getting-Started.md
@@ -19,7 +19,7 @@ AutoGen is powered by collaborative [research studies](/docs/Research) from Micr
### Quickstart
Install from pip: `pip install pyautogen`. Find more options in [Installation](/docs/Installation).
-For [code execution](https://microsoft.github.io/autogen/FAQ#code-execution), we strongly recommend installing the python docker package, and using docker.
+For [code execution](/docs/FAQ#code-execution), we strongly recommend installing the python docker package, and using docker.
#### Multi-Agent Conversation Framework
Autogen enables the next-gen LLM applications with a generic multi-agent conversation framework. It offers customizable and conversable agents which integrate LLMs, tools and human.
@@ -63,12 +63,13 @@ response = autogen.Completion.create(context=test_instance, **config)
* [Code examples](/docs/Examples/AutoGen-Inference).
* [Documentation](/docs/Use-Cases/enhanced_inference).
-### Where to Go Next?
+### Where to Go Next ?
* Understand the use cases for [multi-agent conversation](/docs/Use-Cases/agent_chat) and [enhanced LLM inference](/docs/Use-Cases/enhanced_inference).
* Find [code examples](/docs/Examples/AutoGen-AgentChat).
* Read [SDK](/docs/reference/agentchat/conversable_agent/).
* Learn about [research](/docs/Research) around AutoGen.
+* [Roadmap](https://github.com/orgs/microsoft/projects/989/views/3)
* Chat on [Discord](https://discord.gg/pAbnFJrkgZ).
* Follow on [Twitter](https://twitter.com/pyautogen).
diff --git a/website/docs/Installation.md b/website/docs/Installation.md
index 8310e5949..9d064d4ab 100644
--- a/website/docs/Installation.md
+++ b/website/docs/Installation.md
@@ -1,5 +1,24 @@
# Installation
+## Setup Virtual Environment
+
+When not using a docker container, we recommend using a virtual environment to install AutoGen. This will ensure that the dependencies for AutoGen are isolated from the rest of your system.
+
+You can create a virtual environment with `venv` as below:
+```bash
+python3 -m venv autogen
+source autogen/bin/activate
+```
+
+Another option is with `Conda`, Conda works better at solving dependency conflicts than pip. You can install it by following [this doc](https://docs.conda.io/projects/conda/en/stable/user-guide/install/index.html),
+and then create a virtual environment as below:
+```bash
+conda create -n autogen python=3.10 # python 3.10 is recommended as it's stable and not too old
+conda activate autogen
+```
+
+Now, you're ready to install AutoGen in the virtual environment you've just created.
+
## Python
AutoGen requires **Python version >= 3.8**. It can be installed from pip:
@@ -24,11 +43,33 @@ pip install docker
```
* blendsearch
+
+AutoGen offers a cost-effective hyperparameter optimization technique [EcoOptiGen](https://arxiv.org/abs/2303.04673) for tuning Large Language Models. Please install with the [blendsearch] option to use it.
```bash
pip install "pyautogen[blendsearch]"
```
+Example notebooks:
+[Optimize for Code Generation](https://github.com/microsoft/autogen/blob/main/notebook/oai_completion.ipynb),
+[Optimize for Math](https://github.com/microsoft/autogen/blob/main/notebook/oai_chatgpt_gpt4.ipynb)
+
* retrievechat
+
+AutoGen supports retrieval-augmented generation tasks such as question answering and code generation with RAG agents. Please install with the [retrievechat] option to use it.
```bash
pip install "pyautogen[retrievechat]"
```
+
+Example notebooks:
+[Automated Code Generation and Question Answering with Retrieval Augmented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb),
+[Group Chat with Retrieval Augmented Generation (with 5 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_RAG.ipynb)
+
+* mathchat
+
+AutoGen offers an experimental agent for math problem solving. Please install with the [mathchat] option to use it.
+```bash
+pip install "pyautogen[mathchat]"
+```
+
+Example notebooks:
+[Using MathChat to Solve Math Problems](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_MathChat.ipynb)
diff --git a/website/docs/Use-Cases/agent_chat.md b/website/docs/Use-Cases/agent_chat.md
index 5f8d22619..4267e85fd 100644
--- a/website/docs/Use-Cases/agent_chat.md
+++ b/website/docs/Use-Cases/agent_chat.md
@@ -11,7 +11,7 @@ AutoGen abstracts and implements conversable agents
designed to solve tasks through inter-agent conversations. Specifically, the agents in AutoGen have the following notable features:
- Conversable: Agents in AutoGen are conversable, which means that any agent can send
-and receive messages from other agents to initiate or continue a conversation
+ and receive messages from other agents to initiate or continue a conversation
- Customizable: Agents in AutoGen can be customized to integrate LLMs, humans, tools, or a combination of them.
@@ -20,7 +20,6 @@ The figure below shows the built-in agents in AutoGen.
We have designed a generic `ConversableAgent` class for Agents that are capable of conversing with each other through the exchange of messages to jointly finish a task. An agent can communicate with other agents and perform actions. Different agents can differ in what actions they perform after receiving messages. Two representative subclasses are `AssistantAgent` and `UserProxyAgent`.
-
- The `AssistantAgent` is designed to act as an AI assistant, using LLMs by default but not requiring human input or code execution. It could write Python code (in a Python coding block) for a user to execute when a message (typically a description of a task that needs to be solved) is received. Under the hood, the Python code is written by LLM (e.g., GPT-4). It can also receive the execution results and suggest corrections or bug fixes. Its behavior can be altered by passing a new system message. The LLM [inference](#enhanced-inference) configuration can be configured via `llm_config`.
- The `UserProxyAgent` is conceptually a proxy agent for humans, soliciting human input as the agent's reply at each interaction turn by default and also having the capability to execute code and call functions. The `UserProxyAgent` triggers code execution automatically when it detects an executable code block in the received message and no human user input is provided. Code execution can be disabled by setting the `code_execution_config` parameter to False. LLM-based response is disabled by default. It can be enabled by setting `llm_config` to a dict corresponding to the [inference](/docs/Use-Cases/enhanced_inference) configuration. When `llm_config` is set as a dictionary, `UserProxyAgent` can generate replies using an LLM when code execution is not performed.
@@ -45,6 +44,7 @@ user_proxy = UserProxyAgent(name="user_proxy")
### A Basic Two-Agent Conversation Example
Once the participating agents are constructed properly, one can start a multi-agent conversation session by an initialization step as shown in the following code:
+
```python
# the assistant receives a message from the user, which contains the task description
user_proxy.initiate_chat(
@@ -52,6 +52,7 @@ user_proxy.initiate_chat(
message="""What date is today? Which big tech stock has the largest year-to-date gain this year? How much is the gain?""",
)
```
+
After the initialization step, the conversation could proceed automatically. Find a visual illustration of how the user_proxy and assistant collaboratively solve the above task autonmously below:

@@ -63,6 +64,7 @@ After the initialization step, the conversation could proceed automatically. Fin
### Supporting Diverse Conversation Patterns
#### Conversations with different levels of autonomy, and human-involvement patterns
+
On the one hand, one can achieve fully autonomous conversations after an initialization step. On the other hand, AutoGen can be used to implement human-in-the-loop problem-solving by configuring human involvement levels and patterns (e.g., setting the `human_input_mode` to `ALWAYS`), as human involvement is expected and/or desired in many applications.
#### Static and dynamic conversations
@@ -72,34 +74,48 @@ By adopting the conversation-driven control with both programming language and n
- Registered auto-reply. With the pluggable auto-reply function, one can choose to invoke conversations with other agents depending on the content of the current message and context. A working system demonstrating this type of dynamic conversation can be found in this code example, demonstrating a [dynamic group chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb). In the system, we register an auto-reply function in the group chat manager, which lets LLM decide who the next speaker will be in a group chat setting.
- LLM-based function call. In this approach, LLM decides whether or not to call a particular function depending on the conversation status in each inference call.
-By messaging additional agents in the called functions, the LLM can drive dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant would automatically resort to an expert using function calls.
+ By messaging additional agents in the called functions, the LLM can drive dynamic multi-agent conversation. A working system showcasing this type of dynamic conversation can be found in the [multi-user math problem solving scenario](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb), where a student assistant would automatically resort to an expert using function calls.
### Diverse Applications Implemented with AutoGen
-
The figure below shows six examples of applications built using AutoGen.

-* [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)
-* [Auto Code Generation, Execution, Debugging and Human Feedback](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_human_feedback.ipynb)
-* [Solve Tasks Requiring Web Info](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_web_info.ipynb)
-* [Use Provided Tools as Functions](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call.ipynb)
-* [Automated Task Solving with Coding & Planning Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_planning.ipynb)
-* [Automated Task Solving with GPT-4 + Multiple Human Users](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb)
-* [Automated Chess Game Playing & Chitchatting by GPT-4 Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_chess.ipynb)
-* [Automated Task Solving by Group Chat (with 3 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb)
-* [Automated Data Visualization by Group Chat (with 3 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_vis.ipynb)
-* [Automated Complex Task Solving by Group Chat (with 6 group member agents and 1 manager agent)](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_research.ipynb)
-* [Automated Continual Learning from New Data](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_stream.ipynb)
-* [Teach Agents New Skills & Reuse via Automated Chat](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teaching.ipynb)
-* [Automated Code Generation and Question Answering with Retrieval Augemented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb)
+1. **Code Generation, Execution, and Debugging**
+ - Automated Task Solving with Code Generation, Execution & Debugging - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)
+ - Auto Code Generation, Execution, Debugging and Human Feedback - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_human_feedback.ipynb)
+ - Automated Code Generation and Question Answering with Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb)
+2. **Multi-Agent Collaboration (>3 Agents)**
+
+ - Automated Task Solving with GPT-4 + Multiple Human Users - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_two_users.ipynb)
+ - Automated Task Solving by Group Chat (with 3 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat.ipynb)
+ - Automated Data Visualization by Group Chat (with 3 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_vis.ipynb)
+ - Automated Complex Task Solving by Group Chat (with 6 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_research.ipynb)
+ - Automated Task Solving with Coding & Planning Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_planning.ipynb)
+
+3. **Applications**
+
+ - Automated Chess Game Playing & Chitchatting by GPT-4 Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_chess.ipynb)
+ - Automated Continual Learning from New Data - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_stream.ipynb)
+
+4. **Tool Use**
+
+ - **Web Search**: Solve Tasks Requiring Web Info - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_web_info.ipynb)
+ - Use Provided Tools as Functions - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call.ipynb)
+ - Task Solving with Langchain Provided Tools as Functions - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_function_call.ipynb)
+ - **RAG**: Group Chat with Retrieval Augmented Generation (with 5 group member agents and 1 manager agent) - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_groupchat_RAG.ipynb)
+ - In-depth Guide to OpenAI Utility Functions - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/oai_openai_utils.ipynb)
+
+5. **Agent Teaching and Learning**
+ - Teach Agents New Skills & Reuse via Automated Chat - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teaching.ipynb)
+ - Teach Agents New Facts, User Preferences and Skills Beyond Coding - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
## For Further Reading
-*Interested in the research that leads to this package? Please check the following papers.*
+_Interested in the research that leads to this package? Please check the following papers._
-* [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155). Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Shaokun Zhang, Erkang Zhu, Beibin Li, Li Jiang, Xiaoyun Zhang and Chi Wang. ArXiv 2023.
+- [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation Framework](https://arxiv.org/abs/2308.08155). Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Shaokun Zhang, Erkang Zhu, Beibin Li, Li Jiang, Xiaoyun Zhang and Chi Wang. ArXiv 2023.
-* [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).
+- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).
diff --git a/website/yarn.lock b/website/yarn.lock
index 97a133bb2..53ef22fb2 100644
--- a/website/yarn.lock
+++ b/website/yarn.lock
@@ -145,6 +145,14 @@
dependencies:
"@babel/highlight" "^7.18.6"
+"@babel/code-frame@^7.22.13":
+ version "7.22.13"
+ resolved "https://registry.yarnpkg.com/@babel/code-frame/-/code-frame-7.22.13.tgz#e3c1c099402598483b7a8c46a721d1038803755e"
+ integrity sha512-XktuhWlJ5g+3TJXc5upd9Ks1HutSArik6jf2eAjYFyIOf4ej3RN+184cZbzDvbPnuTJIUhPKKJE3cIsYTiAT3w==
+ dependencies:
+ "@babel/highlight" "^7.22.13"
+ chalk "^2.4.2"
+
"@babel/compat-data@^7.17.7", "@babel/compat-data@^7.20.0", "@babel/compat-data@^7.20.1":
version "7.20.1"
resolved "https://registry.npmmirror.com/@babel/compat-data/-/compat-data-7.20.1.tgz#f2e6ef7790d8c8dbf03d379502dcc246dcce0b30"
@@ -193,7 +201,7 @@
json5 "^2.2.1"
semver "^6.3.0"
-"@babel/generator@^7.12.15", "@babel/generator@^7.12.5", "@babel/generator@^7.20.1", "@babel/generator@^7.20.2":
+"@babel/generator@^7.12.15", "@babel/generator@^7.12.5", "@babel/generator@^7.20.2":
version "7.20.2"
resolved "https://registry.npmmirror.com/@babel/generator/-/generator-7.20.2.tgz#c2e89e22613a039285c1e7b749e2cd0b30b9a481"
integrity sha512-SD75PMIK6i9H8G/tfGvB4KKl4Nw6Ssos9nGgYwxbgyTP0iX/Z55DveoH86rmUB/YHTQQ+ZC0F7xxaY8l2OF44Q==
@@ -202,6 +210,16 @@
"@jridgewell/gen-mapping" "^0.3.2"
jsesc "^2.5.1"
+"@babel/generator@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/generator/-/generator-7.23.0.tgz#df5c386e2218be505b34837acbcb874d7a983420"
+ integrity sha512-lN85QRR+5IbYrMWM6Y4pE/noaQtg4pNiqeNGX60eqOfo6gtEj6uw/JagelB8vVztSd7R6M5n1+PQkDbHbBRU4g==
+ dependencies:
+ "@babel/types" "^7.23.0"
+ "@jridgewell/gen-mapping" "^0.3.2"
+ "@jridgewell/trace-mapping" "^0.3.17"
+ jsesc "^2.5.1"
+
"@babel/helper-annotate-as-pure@^7.18.6":
version "7.18.6"
resolved "https://registry.npmmirror.com/@babel/helper-annotate-as-pure/-/helper-annotate-as-pure-7.18.6.tgz#eaa49f6f80d5a33f9a5dd2276e6d6e451be0a6bb"
@@ -265,6 +283,11 @@
resolved "https://registry.npmmirror.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.18.9.tgz#0c0cee9b35d2ca190478756865bb3528422f51be"
integrity sha512-3r/aACDJ3fhQ/EVgFy0hpj8oHyHpQc+LPtJoY9SzTThAsStm4Ptegq92vqKoE3vD706ZVFWITnMnxucw+S9Ipg==
+"@babel/helper-environment-visitor@^7.22.20":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/helper-environment-visitor/-/helper-environment-visitor-7.22.20.tgz#96159db61d34a29dba454c959f5ae4a649ba9167"
+ integrity sha512-zfedSIzFhat/gFhWfHtgWvlec0nqB9YEIVrpuwjruLlXfUSnA8cJB0miHKwqDnQ7d32aKo2xt88/xZptwxbfhA==
+
"@babel/helper-explode-assignable-expression@^7.18.6":
version "7.18.6"
resolved "https://registry.npmmirror.com/@babel/helper-explode-assignable-expression/-/helper-explode-assignable-expression-7.18.6.tgz#41f8228ef0a6f1a036b8dfdfec7ce94f9a6bc096"
@@ -280,6 +303,14 @@
"@babel/template" "^7.18.10"
"@babel/types" "^7.19.0"
+"@babel/helper-function-name@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/helper-function-name/-/helper-function-name-7.23.0.tgz#1f9a3cdbd5b2698a670c30d2735f9af95ed52759"
+ integrity sha512-OErEqsrxjZTJciZ4Oo+eoZqeW9UIiOcuYKRJA4ZAgV9myA+pOXhhmpfNCKjEH/auVfEYVFJ6y1Tc4r0eIApqiw==
+ dependencies:
+ "@babel/template" "^7.22.15"
+ "@babel/types" "^7.23.0"
+
"@babel/helper-hoist-variables@^7.18.6":
version "7.18.6"
resolved "https://registry.npmmirror.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.18.6.tgz#d4d2c8fb4baeaa5c68b99cc8245c56554f926678"
@@ -287,6 +318,13 @@
dependencies:
"@babel/types" "^7.18.6"
+"@babel/helper-hoist-variables@^7.22.5":
+ version "7.22.5"
+ resolved "https://registry.yarnpkg.com/@babel/helper-hoist-variables/-/helper-hoist-variables-7.22.5.tgz#c01a007dac05c085914e8fb652b339db50d823bb"
+ integrity sha512-wGjk9QZVzvknA6yKIUURb8zY3grXCcOZt+/7Wcy8O2uctxhplmUPkOdlgoNhmdVee2c92JXbf1xpMtVNbfoxRw==
+ dependencies:
+ "@babel/types" "^7.22.5"
+
"@babel/helper-member-expression-to-functions@^7.18.9":
version "7.18.9"
resolved "https://registry.npmmirror.com/@babel/helper-member-expression-to-functions/-/helper-member-expression-to-functions-7.18.9.tgz#1531661e8375af843ad37ac692c132841e2fd815"
@@ -374,16 +412,33 @@
dependencies:
"@babel/types" "^7.18.6"
+"@babel/helper-split-export-declaration@^7.22.6":
+ version "7.22.6"
+ resolved "https://registry.yarnpkg.com/@babel/helper-split-export-declaration/-/helper-split-export-declaration-7.22.6.tgz#322c61b7310c0997fe4c323955667f18fcefb91c"
+ integrity sha512-AsUnxuLhRYsisFiaJwvp1QF+I3KjD5FOxut14q/GzovUe6orHLesW2C7d754kRm53h5gqrz6sFl6sxc4BVtE/g==
+ dependencies:
+ "@babel/types" "^7.22.5"
+
"@babel/helper-string-parser@^7.19.4":
version "7.19.4"
resolved "https://registry.npmmirror.com/@babel/helper-string-parser/-/helper-string-parser-7.19.4.tgz#38d3acb654b4701a9b77fb0615a96f775c3a9e63"
integrity sha512-nHtDoQcuqFmwYNYPz3Rah5ph2p8PFeFCsZk9A/48dPc/rGocJ5J3hAAZ7pb76VWX3fZKu+uEr/FhH5jLx7umrw==
+"@babel/helper-string-parser@^7.22.5":
+ version "7.22.5"
+ resolved "https://registry.yarnpkg.com/@babel/helper-string-parser/-/helper-string-parser-7.22.5.tgz#533f36457a25814cf1df6488523ad547d784a99f"
+ integrity sha512-mM4COjgZox8U+JcXQwPijIZLElkgEpO5rsERVDJTc2qfCDfERyob6k5WegS14SX18IIjv+XD+GrqNumY5JRCDw==
+
"@babel/helper-validator-identifier@^7.18.6", "@babel/helper-validator-identifier@^7.19.1":
version "7.19.1"
resolved "https://registry.npmmirror.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.19.1.tgz#7eea834cf32901ffdc1a7ee555e2f9c27e249ca2"
integrity sha512-awrNfaMtnHUr653GgGEs++LlAvW6w+DcPrOliSMXWCKo597CwL5Acf/wWdNkf/tfEQE3mjkeD1YOVZOUV/od1w==
+"@babel/helper-validator-identifier@^7.22.20":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/helper-validator-identifier/-/helper-validator-identifier-7.22.20.tgz#c4ae002c61d2879e724581d96665583dbc1dc0e0"
+ integrity sha512-Y4OZ+ytlatR8AI+8KZfKuL5urKp7qey08ha31L8b3BwewJAoJamTzyvxPR/5D+KkdJCGPq/+8TukHBlY10FX9A==
+
"@babel/helper-validator-option@^7.18.6":
version "7.18.6"
resolved "https://registry.npmmirror.com/@babel/helper-validator-option/-/helper-validator-option-7.18.6.tgz#bf0d2b5a509b1f336099e4ff36e1a63aa5db4db8"
@@ -417,11 +472,25 @@
chalk "^2.0.0"
js-tokens "^4.0.0"
-"@babel/parser@^7.12.16", "@babel/parser@^7.12.7", "@babel/parser@^7.18.10", "@babel/parser@^7.20.1", "@babel/parser@^7.20.2":
+"@babel/highlight@^7.22.13":
+ version "7.22.20"
+ resolved "https://registry.yarnpkg.com/@babel/highlight/-/highlight-7.22.20.tgz#4ca92b71d80554b01427815e06f2df965b9c1f54"
+ integrity sha512-dkdMCN3py0+ksCgYmGG8jKeGA/8Tk+gJwSYYlFGxG5lmhfKNoAy004YpLxpS1W2J8m/EK2Ew+yOs9pVRwO89mg==
+ dependencies:
+ "@babel/helper-validator-identifier" "^7.22.20"
+ chalk "^2.4.2"
+ js-tokens "^4.0.0"
+
+"@babel/parser@^7.12.16", "@babel/parser@^7.12.7", "@babel/parser@^7.18.10", "@babel/parser@^7.20.2":
version "7.20.2"
resolved "https://registry.npmmirror.com/@babel/parser/-/parser-7.20.2.tgz#9aeb9b92f64412b5f81064d46f6a1ac0881337f4"
integrity sha512-afk318kh2uKbo7BEj2QtEi8HVCGrwHUffrYDy7dgVcSa2j9lY3LDjPzcyGdpX7xgm35aWqvciZJ4WKmdF/SxYg==
+"@babel/parser@^7.22.15", "@babel/parser@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.23.0.tgz#da950e622420bf96ca0d0f2909cdddac3acd8719"
+ integrity sha512-vvPKKdMemU85V9WE/l5wZEmImpCtLqbnTvqDS2U1fJ96KrxoW7KrXhNsNCblQlg8Ck4b85yxdTyelsMUgFUXiw==
+
"@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression@^7.18.6":
version "7.18.6"
resolved "https://registry.npmmirror.com/@babel/plugin-bugfix-safari-id-destructuring-collision-in-function-expression/-/plugin-bugfix-safari-id-destructuring-collision-in-function-expression-7.18.6.tgz#da5b8f9a580acdfbe53494dba45ea389fb09a4d2"
@@ -1150,19 +1219,28 @@
"@babel/parser" "^7.18.10"
"@babel/types" "^7.18.10"
-"@babel/traverse@^7.12.13", "@babel/traverse@^7.12.9", "@babel/traverse@^7.19.0", "@babel/traverse@^7.19.1", "@babel/traverse@^7.20.1":
- version "7.20.1"
- resolved "https://registry.npmmirror.com/@babel/traverse/-/traverse-7.20.1.tgz#9b15ccbf882f6d107eeeecf263fbcdd208777ec8"
- integrity sha512-d3tN8fkVJwFLkHkBN479SOsw4DMZnz8cdbL/gvuDuzy3TS6Nfw80HuQqhw1pITbIruHyh7d1fMA47kWzmcUEGA==
+"@babel/template@^7.22.15":
+ version "7.22.15"
+ resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.22.15.tgz#09576efc3830f0430f4548ef971dde1350ef2f38"
+ integrity sha512-QPErUVm4uyJa60rkI73qneDacvdvzxshT3kksGqlGWYdOTIUOwJ7RDUL8sGqslY1uXWSL6xMFKEXDS3ox2uF0w==
dependencies:
- "@babel/code-frame" "^7.18.6"
- "@babel/generator" "^7.20.1"
- "@babel/helper-environment-visitor" "^7.18.9"
- "@babel/helper-function-name" "^7.19.0"
- "@babel/helper-hoist-variables" "^7.18.6"
- "@babel/helper-split-export-declaration" "^7.18.6"
- "@babel/parser" "^7.20.1"
- "@babel/types" "^7.20.0"
+ "@babel/code-frame" "^7.22.13"
+ "@babel/parser" "^7.22.15"
+ "@babel/types" "^7.22.15"
+
+"@babel/traverse@^7.12.13", "@babel/traverse@^7.12.9", "@babel/traverse@^7.19.0", "@babel/traverse@^7.19.1", "@babel/traverse@^7.20.1":
+ version "7.23.2"
+ resolved "https://registry.yarnpkg.com/@babel/traverse/-/traverse-7.23.2.tgz#329c7a06735e144a506bdb2cad0268b7f46f4ad8"
+ integrity sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==
+ dependencies:
+ "@babel/code-frame" "^7.22.13"
+ "@babel/generator" "^7.23.0"
+ "@babel/helper-environment-visitor" "^7.22.20"
+ "@babel/helper-function-name" "^7.23.0"
+ "@babel/helper-hoist-variables" "^7.22.5"
+ "@babel/helper-split-export-declaration" "^7.22.6"
+ "@babel/parser" "^7.23.0"
+ "@babel/types" "^7.23.0"
debug "^4.1.0"
globals "^11.1.0"
@@ -1175,6 +1253,15 @@
"@babel/helper-validator-identifier" "^7.19.1"
to-fast-properties "^2.0.0"
+"@babel/types@^7.22.15", "@babel/types@^7.22.5", "@babel/types@^7.23.0":
+ version "7.23.0"
+ resolved "https://registry.yarnpkg.com/@babel/types/-/types-7.23.0.tgz#8c1f020c9df0e737e4e247c0619f58c68458aaeb"
+ integrity sha512-0oIyUfKoI3mSqMvsxBdclDwxXKXAUA8v/apZbc+iSyARYou1o8ZGDxbUYyLFoW2arqS2jDGqJuZvv1d/io1axg==
+ dependencies:
+ "@babel/helper-string-parser" "^7.22.5"
+ "@babel/helper-validator-identifier" "^7.22.20"
+ to-fast-properties "^2.0.0"
+
"@docsearch/css@3.3.0":
version "3.3.0"
resolved "https://registry.npmmirror.com/@docsearch/css/-/css-3.3.0.tgz#d698e48302d12240d7c2f7452ccb2d2239a8cd80"
@@ -1648,6 +1735,11 @@
resolved "https://registry.npmmirror.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78"
integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w==
+"@jridgewell/resolve-uri@^3.1.0":
+ version "3.1.1"
+ resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz#c08679063f279615a3326583ba3a90d1d82cc721"
+ integrity sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==
+
"@jridgewell/set-array@^1.0.0", "@jridgewell/set-array@^1.0.1":
version "1.1.2"
resolved "https://registry.npmmirror.com/@jridgewell/set-array/-/set-array-1.1.2.tgz#7c6cf998d6d20b914c0a55a91ae928ff25965e72"
@@ -1666,6 +1758,11 @@
resolved "https://registry.npmmirror.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz#add4c98d341472a289190b424efbdb096991bb24"
integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==
+"@jridgewell/sourcemap-codec@^1.4.14":
+ version "1.4.15"
+ resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz#d7c6e6755c78567a951e04ab52ef0fd26de59f32"
+ integrity sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==
+
"@jridgewell/trace-mapping@^0.3.14", "@jridgewell/trace-mapping@^0.3.9":
version "0.3.17"
resolved "https://registry.npmmirror.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.17.tgz#793041277af9073b0951a7fe0f0d8c4c98c36985"
@@ -1674,6 +1771,14 @@
"@jridgewell/resolve-uri" "3.1.0"
"@jridgewell/sourcemap-codec" "1.4.14"
+"@jridgewell/trace-mapping@^0.3.17":
+ version "0.3.20"
+ resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz#72e45707cf240fa6b081d0366f8265b0cd10197f"
+ integrity sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==
+ dependencies:
+ "@jridgewell/resolve-uri" "^3.1.0"
+ "@jridgewell/sourcemap-codec" "^1.4.14"
+
"@leichtgewicht/ip-codec@^2.0.1":
version "2.0.4"
resolved "https://registry.npmmirror.com/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz#b2ac626d6cb9c8718ab459166d4bb405b8ffa78b"