From 50ac5476377c1b41330589a6cfc5c4e65b93079f Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 25 Oct 2023 10:38:43 +0530 Subject: [PATCH] feat: Qdrant vector store support (#303) * feat: QdrantRetrieveUserProxyAgent * fix: QdrantRetrieveUserProxyAgent docstring * chore: batch of 500 all CPU cores * chore: conditional import for tests * chore: config parallel, batch 100 * chore: collection creation params * chore: conditonal payload indexing fastembed import check * docs: notebook for QdrantRetrieveUserProxyAgent * docs: update docs link * docs: notebook examples update * chore: hnsw, payload index reference * docs: notebook docs_path update * Update test/agentchat/test_qdrant_retrievechat.py Co-authored-by: Li Jiang * chore: update notebook output * Fix format --------- Co-authored-by: Li Jiang --- .../qdrant_retrieve_user_proxy_agent.py | 266 ++++ notebook/agentchat_qdrant_RetrieveChat.ipynb | 1234 +++++++++++++++++ test/agentchat/test_qdrant_retrievechat.py | 102 ++ website/docs/Examples/AutoGen-AgentChat.md | 2 + 4 files changed, 1604 insertions(+) create mode 100644 autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py create mode 100644 notebook/agentchat_qdrant_RetrieveChat.ipynb create mode 100644 test/agentchat/test_qdrant_retrievechat.py diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py new file mode 100644 index 000000000..b348b07e0 --- /dev/null +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -0,0 +1,266 @@ +from typing import Callable, Dict, List, Optional + +from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent +from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks +import logging + +logger = logging.getLogger(__name__) + +try: + from qdrant_client import QdrantClient, models + from qdrant_client.fastembed_common import QueryResponse + import fastembed +except ImportError as e: + logging.fatal("Failed to import qdrant_client with fastembed. Try running 'pip install qdrant_client[fastembed]'") + raise e + + +class QdrantRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def __init__( + self, + name="RetrieveChatAgent", + human_input_mode: str | None = "ALWAYS", + is_termination_msg: Callable[[Dict], bool] | None = None, + retrieve_config: Dict | None = None, + **kwargs, + ): + """ + Args: + name (str): name of the agent. + human_input_mode (str): whether to ask for human inputs every time a message is received. + Possible values are "ALWAYS", "TERMINATE", "NEVER". + (1) When "ALWAYS", the agent prompts for human input every time a message is received. + Under this mode, the conversation stops when the human input is "exit", + or when is_termination_msg is True and there is no human input. + (2) When "TERMINATE", the agent only prompts for human input only when a termination message is received or + the number of auto reply reaches the max_consecutive_auto_reply. + (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops + when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". + retrieve_config (dict or None): config for the retrieve agent. + To use default config, set to None. Otherwise, set to a dictionary with the following keys: + - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System + prompt will be different for different tasks. The default value is `default`, which supports both code and qa. + - client (Optional, qdrant_client.QdrantClient(":memory:")): A QdrantClient instance. If not provided, an in-memory instance will be assigned. Not recommended for production. + will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function. + - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file, + or the url to a single file. Default is None, which works only if the collection is already created. + - collection_name (Optional, str): the name of the collection. + If key not provided, a default name `autogen-docs` will be used. + - model (Optional, str): the model to use for the retrieve chat. + If key not provided, a default model `gpt-4` will be used. + - chunk_token_size (Optional, int): the chunk token size for the retrieve chat. + If key not provided, a default size `max_tokens * 0.4` will be used. + - context_max_tokens (Optional, int): the context max token size for the retrieve chat. + If key not provided, a default size `max_tokens * 0.8` will be used. + - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are + "multi_lines" and "one_line". If key not provided, a default mode `multi_lines` will be used. + - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True. + If chunk_mode is "one_line", this parameter will be ignored. + - embedding_model (Optional, str): the embedding model to use for the retrieve chat. + If key not provided, a default model `BAAI/bge-small-en-v1.5` will be used. All available models + can be found at `https://qdrant.github.io/fastembed/examples/Supported_Models/`. + - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None. + - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". + If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. + - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. + - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string. + The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name). + Default is None, tiktoken will be used and may not be accurate for non-OpenAI models. + - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings. + Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`. + - parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores. + - on_disk (Optional, bool): Whether to store the collection on disk. Default is False. + - quantization_config: Quantization configuration. If None, quantization will be disabled. + - hnsw_config: HNSW configuration. If None, default configuration will be used. + You can find more info about the hnsw configuration options at https://qdrant.tech/documentation/concepts/indexing/#vector-index. + API Reference: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection + - payload_indexing: Whether to create a payload index for the document field. Default is False. + You can find more info about the payload indexing options at https://qdrant.tech/documentation/concepts/indexing/#payload-index + API Reference: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_field_index + **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). + + """ + super().__init__(name, human_input_mode, is_termination_msg, retrieve_config, **kwargs) + self._client = self._retrieve_config.get("client", QdrantClient(":memory:")) + self._embedding_model = self._retrieve_config.get("embedding_model", "BAAI/bge-small-en-v1.5") + # Uses all available CPU cores to encode data when set to 0 + self._parallel = self._retrieve_config.get("parallel", 0) + self._on_disk = self._retrieve_config.get("on_disk", False) + self._quantization_config = self._retrieve_config.get("quantization_config", None) + self._hnsw_config = self._retrieve_config.get("hnsw_config", None) + self._payload_indexing = self._retrieve_config.get("payload_indexing", False) + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): + """ + Args: + problem (str): the problem to be solved. + n_results (int): the number of results to be retrieved. + search_string (str): only docs containing this string will be retrieved. + """ + if not self._collection: + print("Trying to create collection.") + create_qdrant_from_dir( + dir_path=self._docs_path, + max_tokens=self._chunk_token_size, + client=self._client, + collection_name=self._collection_name, + chunk_mode=self._chunk_mode, + must_break_at_empty_line=self._must_break_at_empty_line, + embedding_model=self._embedding_model, + custom_text_split_function=self.custom_text_split_function, + parallel=self._parallel, + on_disk=self._on_disk, + quantization_config=self._quantization_config, + hnsw_config=self._hnsw_config, + payload_indexing=self._payload_indexing, + ) + self._collection = True + + results = query_qdrant( + query_texts=problem, + n_results=n_results, + search_string=search_string, + client=self._client, + collection_name=self._collection_name, + embedding_model=self._embedding_model, + ) + self._results = results + + +def create_qdrant_from_dir( + dir_path: str, + max_tokens: int = 4000, + client: QdrantClient = None, + collection_name: str = "all-my-documents", + chunk_mode: str = "multi_lines", + must_break_at_empty_line: bool = True, + embedding_model: str = "BAAI/bge-small-en-v1.5", + custom_text_split_function: Callable = None, + parallel: int = 0, + on_disk: bool = False, + quantization_config: Optional[models.QuantizationConfig] = None, + hnsw_config: Optional[models.HnswConfigDiff] = None, + payload_indexing: bool = False, + qdrant_client_options: Optional[Dict] = {}, +): + """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a url to + a single file. + + Args: + dir_path (str): the path to the directory, file or url. + max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000. + client (Optional, QdrantClient): the QdrantClient instance. Default is None. + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". + chunk_mode (Optional, str): the chunk mode. Default is "multi_lines". + must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True. + embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5". The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/. + parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores + on_disk (Optional, bool): Whether to store the collection on disk. Default is False. + quantization_config: Quantization configuration. If None, quantization will be disabled. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection + hnsw_config: HNSW configuration. If None, default configuration will be used. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection + payload_indexing: Whether to create a payload index for the document field. Default is False. + qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58. + """ + if client is None: + client = QdrantClient(**qdrant_client_options) + client.set_model(embedding_model) + + if custom_text_split_function is not None: + chunks = split_files_to_chunks( + get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function + ) + else: + chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line) + logger.info(f"Found {len(chunks)} chunks.") + + # Check if collection by same name exists, if not, create it with custom options + try: + client.get_collection(collection_name=collection_name) + except Exception: + client.create_collection( + collection_name=collection_name, + vectors_config=client.get_fastembed_vector_params( + on_disk=on_disk, quantization_config=quantization_config, hnsw_config=hnsw_config + ), + ) + client.get_collection(collection_name=collection_name) + + # Upsert in batch of 100 or less if the total number of chunks is less than 100 + for i in range(0, len(chunks), min(100, len(chunks))): + end_idx = i + min(100, len(chunks) - i) + client.add(collection_name, documents=chunks[i:end_idx], ids=[j for j in range(i, end_idx)], parallel=parallel) + + # Create a payload index for the document field + # Enables highly efficient payload filtering. Reference: https://qdrant.tech/documentation/concepts/indexing/#indexing + # Creating an index requires additional computational resources and memory. + # If filtering performance is critical, we can consider creating an index. + if payload_indexing: + client.create_payload_index( + collection_name=collection_name, + field_name="document", + field_schema=models.TextIndexParams( + type="text", + tokenizer=models.TokenizerType.WORD, + min_token_len=2, + max_token_len=15, + ), + ) + + +def query_qdrant( + query_texts: List[str], + n_results: int = 10, + client: QdrantClient = None, + collection_name: str = "all-my-documents", + search_string: str = "", + embedding_model: str = "BAAI/bge-small-en-v1.5", + qdrant_client_options: Optional[Dict] = {}, +) -> List[List[QueryResponse]]: + """Perform a similarity search with filters on a Qdrant collection + + Args: + query_texts (List[str]): the query texts. + n_results (Optional, int): the number of results to return. Default is 10. + client (Optional, API): the QdrantClient instance. A default in-memory client will be instantiated if None. + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". + search_string (Optional, str): the search string. Default is "". + embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if embedding_function is not None. + qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58. + + Returns: + List[List[QueryResponse]]: the query result. The format is: + class QueryResponse(BaseModel, extra="forbid"): # type: ignore + id: Union[str, int] + embedding: Optional[List[float]] + metadata: Dict[str, Any] + document: str + score: float + """ + if client is None: + client = QdrantClient(**qdrant_client_options) + client.set_model(embedding_model) + + results = client.query_batch( + collection_name, + query_texts, + limit=n_results, + query_filter=models.Filter( + must=[ + models.FieldCondition( + key="document", + match=models.MatchText(text=search_string), + ) + ] + ) + if search_string + else None, + ) + + data = { + "ids": [[result.id for result in sublist] for sublist in results], + "documents": [[result.document for result in sublist] for sublist in results], + } + return data diff --git a/notebook/agentchat_qdrant_RetrieveChat.ipynb b/notebook/agentchat_qdrant_RetrieveChat.ipynb new file mode 100644 index 000000000..42a5cf82f --- /dev/null +++ b/notebook/agentchat_qdrant_RetrieveChat.ipynb @@ -0,0 +1,1234 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "# Using RetrieveChat with Qdrant for Retrieve Augmented Code Generation and Question Answering\n", + "\n", + "[Qdrant](https://qdrant.tech/) is a high-performance vector search engine/database.\n", + "\n", + "This notebook demonstrates the usage of `QdrantRetrieveUserProxyAgent` for RAG, based on [agentchat_RetrieveChat.ipynb](https://colab.research.google.com/github/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb).\n", + "\n", + "\n", + "RetrieveChat is a conversational system for retrieve augmented code generation and question answering. In this notebook, we demonstrate how to utilize RetrieveChat to generate code and answer questions based on customized documentations that are not present in the LLM's training dataset. RetrieveChat uses the `RetrieveAssistantAgent` and `QdrantRetrieveUserProxyAgent`, which is similar to the usage of `AssistantAgent` and `UserProxyAgent` in other notebooks (e.g., [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)).\n", + "\n", + "We'll demonstrate usage of RetrieveChat with Qdrant for code generation and question answering w/ human feedback.\n", + "\n", + "\n", + "## Requirements\n", + "\n", + "AutoGen requires `Python>=3.8`. To run this notebook example, please install the [retrievechat] option.\n", + "```bash\n", + "pip install \"pyautogen[retrievechat] flaml[automl] qdrant_client[fastembed]\"\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set your API Endpoint\n", + "\n", + "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "models to use: ['gpt-3.5-turbo']\n" + ] + } + ], + "source": [ + "import autogen\n", + "\n", + "config_list = autogen.config_list_from_json(\n", + " env_or_file=\"OAI_CONFIG_LIST\",\n", + " file_location=\".\",\n", + " filter_dict={\n", + " \"model\": {\n", + " \"gpt-4\",\n", + " \"gpt4\",\n", + " \"gpt-4-32k\",\n", + " \"gpt-4-32k-0314\",\n", + " \"gpt-35-turbo\",\n", + " \"gpt-3.5-turbo\",\n", + " }\n", + " },\n", + ")\n", + "\n", + "assert len(config_list) > 0\n", + "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). Only the gpt-4 and gpt-3.5-turbo models are kept in the list based on the filter condition.\n", + "\n", + "The config list looks like the following:\n", + "```python\n", + "config_list = [\n", + " {\n", + " 'model': 'gpt-4',\n", + " 'api_key': '',\n", + " },\n", + " {\n", + " 'model': 'gpt-4',\n", + " 'api_key': '',\n", + " 'api_base': '',\n", + " 'api_type': 'azure',\n", + " 'api_version': '2023-06-01-preview',\n", + " },\n", + " {\n", + " 'model': 'gpt-3.5-turbo',\n", + " 'api_key': '',\n", + " 'api_base': '',\n", + " 'api_type': 'azure',\n", + " 'api_version': '2023-06-01-preview',\n", + " },\n", + "]\n", + "```\n", + "\n", + "If you open this notebook in colab, you can upload your files by clicking the file icon on the left panel and then choose \"upload file\" icon.\n", + "\n", + "You can set the value of config_list in other ways you prefer, e.g., loading from a YAML file." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accepted file formats for `docs_path`:\n", + "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n" + ] + } + ], + "source": [ + "# Accepted file formats for that can be stored in \n", + "# a vector database instance\n", + "from autogen.retrieve_utils import TEXT_FORMATS\n", + "\n", + "print(\"Accepted file formats for `docs_path`:\")\n", + "print(TEXT_FORMATS)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Construct agents for RetrieveChat\n", + "\n", + "We start by initialzing the `RetrieveAssistantAgent` and `QdrantRetrieveUserProxyAgent`. The system message needs to be set to \"You are a helpful assistant.\" for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `QdrantRetrieveUserProxyAgent.generate_init_prompt` to combine the instructions and a retrieval augmented generation task for an initial prompt to be sent to the LLM assistant.\n", + "\n", + "### You can find the list of all the embedding models supported by Qdrant [here](https://qdrant.github.io/fastembed/examples/Supported_Models/)." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n", + "from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import QdrantRetrieveUserProxyAgent\n", + "from qdrant_client import QdrantClient\n", + "\n", + "autogen.ChatCompletion.start_logging()\n", + "\n", + "# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n", + "assistant = RetrieveAssistantAgent(\n", + " name=\"assistant\", \n", + " system_message=\"You are a helpful assistant.\",\n", + " llm_config={\n", + " \"request_timeout\": 600,\n", + " \"seed\": 42,\n", + " \"config_list\": config_list,\n", + " },\n", + ")\n", + "\n", + "# 2. create the QdrantRetrieveUserProxyAgent instance named \"ragproxyagent\"\n", + "# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n", + "# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. By default, \n", + "# it is set to None, which works only if the collection is already created.\n", + "# \n", + "# Here we generated the documentations from FLAML's docstrings. Not needed if you just want to try this notebook but not to reproduce the\n", + "# outputs. Clone the FLAML (https://github.com/microsoft/FLAML) repo and navigate to its website folder. Pip install and run `pydoc-markdown`\n", + "# and it will generate folder `reference` under `website/docs`.\n", + "#\n", + "# `task` indicates the kind of task we're working on. In this example, it's a `code` task.\n", + "# `chunk_token_size` is the chunk token size for the retrieve chat. By default, it is set to `max_tokens * 0.6`, here we set it to 2000.\n", + "# We use an in-memory QdrantClient instance here. Not recommended for production.\n", + "# Get the installation instructions here: https://qdrant.tech/documentation/guides/installation/\n", + "ragproxyagent = QdrantRetrieveUserProxyAgent(\n", + " name=\"ragproxyagent\",\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=10,\n", + " retrieve_config={\n", + " \"task\": \"code\",\n", + " \"docs_path\": \"~/path/to/FLAML/website/docs/reference\", # change this to your own path, such as https://raw.githubusercontent.com/microsoft/autogen/main/README.md\n", + " \"chunk_token_size\": 2000,\n", + " \"model\": config_list[0][\"model\"],\n", + " \"client\": QdrantClient(\":memory:\"),\n", + " \"embedding_model\": \"BAAI/bge-small-en-v1.5\",\n", + " },\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Example 1\n", + "\n", + "[back to top](#toc)\n", + "\n", + "Use RetrieveChat to answer a question and ask for human-in-loop feedbacks.\n", + "\n", + "Problem: Is there a function named `tune_automl` in FLAML?" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding doc_id 69 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 47 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 64 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 65 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 21 to context.\u001b[0m\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: Is there a function called tune_automl?\n", + "\n", + "Context is: {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/autogen/agentchat/contrib/math_user_proxy_agent\",\n", + " \"reference/autogen/agentchat/contrib/retrieve_assistant_agent\",\n", + " \"reference/autogen/agentchat/contrib/retrieve_user_proxy_agent\"\n", + " ],\n", + " \"label\": \"autogen.agentchat.contrib\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/autogen/agentchat/agent\",\n", + " \"reference/autogen/agentchat/assistant_agent\",\n", + " \"reference/autogen/agentchat/conversable_agent\",\n", + " \"reference/autogen/agentchat/groupchat\",\n", + " \"reference/autogen/agentchat/user_proxy_agent\"\n", + " ],\n", + " \"label\": \"autogen.agentchat\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/autogen/oai/completion\",\n", + " \"reference/autogen/oai/openai_utils\"\n", + " ],\n", + " \"label\": \"autogen.oai\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/autogen/code_utils\",\n", + " \"reference/autogen/math_utils\",\n", + " \"reference/autogen/retrieve_utils\"\n", + " ],\n", + " \"label\": \"autogen\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/nlp/huggingface/trainer\",\n", + " \"reference/automl/nlp/huggingface/training_args\",\n", + " \"reference/automl/nlp/huggingface/utils\"\n", + " ],\n", + " \"label\": \"automl.nlp.huggingface\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/automl/nlp/utils\"\n", + " ],\n", + " \"label\": \"automl.nlp\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/spark/metrics\",\n", + " \"reference/automl/spark/utils\"\n", + " ],\n", + " \"label\": \"automl.spark\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/task/task\",\n", + " \"reference/automl/task/time_series_task\"\n", + " ],\n", + " \"label\": \"automl.task\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/time_series/sklearn\",\n", + " \"reference/automl/time_series/tft\",\n", + " \"reference/automl/time_series/ts_data\",\n", + " \"reference/automl/time_series/ts_model\"\n", + " ],\n", + " \"label\": \"automl.time_series\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/automl/automl\",\n", + " \"reference/automl/data\",\n", + " \"reference/automl/ml\",\n", + " \"reference/automl/model\",\n", + " \"reference/automl/state\"\n", + " ],\n", + " \"label\": \"automl\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/default/estimator\",\n", + " \"reference/default/greedy\",\n", + " \"reference/default/portfolio\",\n", + " \"reference/default/suggest\"\n", + " ],\n", + " \"label\": \"default\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/onlineml/autovw\",\n", + " \"reference/onlineml/trial\",\n", + " \"reference/onlineml/trial_runner\"\n", + " ],\n", + " \"label\": \"onlineml\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/scheduler/online_scheduler\",\n", + " \"reference/tune/scheduler/trial_scheduler\"\n", + " ],\n", + " \"label\": \"tune.scheduler\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/searcher/blendsearch\",\n", + " \"reference/tune/searcher/cfo_cat\",\n", + " \"reference/tune/searcher/flow2\",\n", + " \"reference/tune/searcher/online_searcher\",\n", + " \"reference/tune/searcher/search_thread\",\n", + " \"reference/tune/searcher/suggestion\",\n", + " \"reference/tune/searcher/variant_generator\"\n", + " ],\n", + " \"label\": \"tune.searcher\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/spark/utils\"\n", + " ],\n", + " \"label\": \"tune.spark\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/tune/analysis\",\n", + " \"reference/tune/sample\",\n", + " \"reference/tune/space\",\n", + " \"reference/tune/trial\",\n", + " \"reference/tune/trial_runner\",\n", + " \"reference/tune/tune\",\n", + " \"reference/tune/utils\"\n", + " ],\n", + " \"label\": \"tune\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/config\"\n", + " ],\n", + " \"label\": \"Reference\",\n", + " \"type\": \"category\"\n", + "}\n", + "---\n", + "sidebar_label: config\n", + "title: config\n", + "---\n", + "\n", + "!\n", + "* Copyright (c) Microsoft Corporation. All rights reserved.\n", + "* Licensed under the MIT License.\n", + "\n", + "#### PENALTY\n", + "\n", + "penalty term for constraints\n", + "\n", + "\n", + "---\n", + "sidebar_label: trial_scheduler\n", + "title: tune.scheduler.trial_scheduler\n", + "---\n", + "\n", + "## TrialScheduler Objects\n", + "\n", + "```python\n", + "class TrialScheduler()\n", + "```\n", + "\n", + "Interface for implementing a Trial Scheduler class.\n", + "\n", + "#### CONTINUE\n", + "\n", + "Status for continuing trial execution\n", + "\n", + "#### PAUSE\n", + "\n", + "Status for pausing trial execution\n", + "\n", + "#### STOP\n", + "\n", + "Status for stopping trial execution\n", + "\n", + "\n", + "---\n", + "sidebar_label: retrieve_user_proxy_agent\n", + "title: autogen.agentchat.contrib.retrieve_user_proxy_agent\n", + "---\n", + "\n", + "## RetrieveUserProxyAgent Objects\n", + "\n", + "```python\n", + "class RetrieveUserProxyAgent(UserProxyAgent)\n", + "```\n", + "\n", + "#### \\_\\_init\\_\\_\n", + "\n", + "```python\n", + "def __init__(name=\"RetrieveChatAgent\",\n", + " is_termination_msg: Optional[Callable[\n", + " [Dict], bool]] = _is_termination_msg_retrievechat,\n", + " human_input_mode: Optional[str] = \"ALWAYS\",\n", + " retrieve_config: Optional[Dict] = None,\n", + " **kwargs)\n", + "```\n", + "\n", + "**Arguments**:\n", + "\n", + "- `name` _str_ - name of the agent.\n", + "- `human_input_mode` _str_ - whether to ask for human inputs every time a message is received.\n", + " Possible values are \"ALWAYS\", \"TERMINATE\", \"NEVER\".\n", + " (1) When \"ALWAYS\", the agent prompts for human input every time a message is received.\n", + " Under this mode, the conversation stops when the human input is \"exit\",\n", + " or when is_termination_msg is True and there is no human input.\n", + " (2) When \"TERMINATE\", the agent only prompts for human input only when a termination message is received or\n", + " the number of auto reply reaches the max_consecutive_auto_reply.\n", + " (3) When \"NEVER\", the agent will never prompt for human input. Under this mode, the conversation stops\n", + " when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True.\n", + "- `retrieve_config` _dict or None_ - config for the retrieve agent.\n", + " To use default config, set to None. Otherwise, set to a dictionary with the following keys:\n", + " - task (Optional, str): the task of the retrieve chat. Possible values are \"code\", \"qa\" and \"default\". System\n", + " prompt will be different for different tasks. The default value is `default`, which supports both code and qa.\n", + " - client (Optional, chromadb.Client): the chromadb client.\n", + " If key not provided, a default client `chromadb.Client()` will be used.\n", + " - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,\n", + " or the url to a single file. If key not provided, a default path `./docs` will be used.\n", + " - collection_name (Optional, str): the name of the collection.\n", + " If key not provided, a default name `flaml-docs` will be used.\n", + " - model (Optional, str): the model to use for the retrieve chat.\n", + " If key not provided, a default model `gpt-4` will be used.\n", + " - chunk_token_size (Optional, int): the chunk token size for the retrieve chat.\n", + " If key not provided, a default size `max_tokens * 0.4` will be used.\n", + " - context_max_tokens (Optional, int): the context max token size for the retrieve chat.\n", + " If key not provided, a default size `max_tokens * 0.8` will be used.\n", + " - chunk_mode (Optional, str): the chunk mode for the retrieve chat. Possible values are\n", + " \"multi_lines\" and \"one_line\". If key not provided, a default mode `multi_lines` will be used.\n", + " - must_break_at_empty_line (Optional, bool): chunk will only break at empty line if True. Default is True.\n", + " If chunk_mode is \"one_line\", this parameter will be ignored.\n", + " - embedding_model (Optional, str): the embedding model to use for the retrieve chat.\n", + " If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models\n", + " can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a\n", + " fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended.\n", + " - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None.\n", + "- `**kwargs` _dict_ - other kwargs in [UserProxyAgent](user_proxy_agent#__init__).\n", + "\n", + "#### generate\\_init\\_message\n", + "\n", + "```python\n", + "def generate_init_message(problem: str,\n", + " n_results: int = 20,\n", + " search_string: str = \"\")\n", + "```\n", + "\n", + "Generate an initial message with the given problem and prompt.\n", + "\n", + "**Arguments**:\n", + "\n", + "- `problem` _str_ - the problem to be solved.\n", + "- `n_results` _int_ - the number of results to be retrieved.\n", + "- `search_string` _str_ - only docs containing this string will be retrieved.\n", + " \n", + "\n", + "**Returns**:\n", + "\n", + "- `str` - the generated prompt ready to be sent to the assistant agent.\n", + "\n", + "\n", + "---\n", + "sidebar_label: retrieve_assistant_agent\n", + "title: autogen.agentchat.contrib.retrieve_assistant_agent\n", + "---\n", + "\n", + "## RetrieveAssistantAgent Objects\n", + "\n", + "```python\n", + "class RetrieveAssistantAgent(AssistantAgent)\n", + "```\n", + "\n", + "(Experimental) Retrieve Assistant agent, designed to solve a task with LLM.\n", + "\n", + "RetrieveAssistantAgent is a subclass of AssistantAgent configured with a default system message.\n", + "The default system message is designed to solve a task with LLM,\n", + "including suggesting python code blocks and debugging.\n", + "`human_input_mode` is default to \"NEVER\"\n", + "and `code_execution_config` is default to False.\n", + "This agent doesn't execute code by default, and expects the user to execute the code.\n", + "\n", + "\n", + "---\n", + "sidebar_label: utils\n", + "title: automl.nlp.huggingface.utils\n", + "---\n", + "\n", + "#### todf\n", + "\n", + "```python\n", + "def todf(X, Y, column_name)\n", + "```\n", + "\n", + "todf converts Y from any format (list, pandas.Series, numpy array) to a DataFrame before being returned\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "No, there is no function called `tune_automl` in the given context.\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# reset the assistant. Always reset the assistant before starting a new conversation.\n", + "assistant.reset()\n", + "\n", + "qa_problem = \"Is there a function called tune_automl?\"\n", + "ragproxyagent.initiate_chat(assistant, problem=qa_problem)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Example 2\n", + "\n", + "[back to top](#toc)\n", + "\n", + "Use RetrieveChat to answer a question that is not related to code generation.\n", + "\n", + "Problem: Who is the author of FLAML?" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32mAdding doc_id 0 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 21 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 47 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 35 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 41 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 69 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 34 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 22 to context.\u001b[0m\n", + "\u001b[32mAdding doc_id 51 to context.\u001b[0m\n", + "\u001b[33mragproxyagent\u001b[0m (to assistant):\n", + "\n", + "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n", + "context provided by the user.\n", + "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n", + "For code generation, you must obey the following rules:\n", + "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n", + "Rule 2. You must follow the formats below to write your code:\n", + "```language\n", + "# your code\n", + "```\n", + "\n", + "User's question is: Who is the author of FLAML?\n", + "\n", + "Context is: ---\n", + "sidebar_label: config\n", + "title: config\n", + "---\n", + "\n", + "!\n", + "* Copyright (c) Microsoft Corporation. All rights reserved.\n", + "* Licensed under the MIT License.\n", + "\n", + "#### PENALTY\n", + "\n", + "penalty term for constraints\n", + "\n", + "\n", + "---\n", + "sidebar_label: utils\n", + "title: automl.nlp.huggingface.utils\n", + "---\n", + "\n", + "#### todf\n", + "\n", + "```python\n", + "def todf(X, Y, column_name)\n", + "```\n", + "\n", + "todf converts Y from any format (list, pandas.Series, numpy array) to a DataFrame before being returned\n", + "\n", + "\n", + "---\n", + "sidebar_label: trial_scheduler\n", + "title: tune.scheduler.trial_scheduler\n", + "---\n", + "\n", + "## TrialScheduler Objects\n", + "\n", + "```python\n", + "class TrialScheduler()\n", + "```\n", + "\n", + "Interface for implementing a Trial Scheduler class.\n", + "\n", + "#### CONTINUE\n", + "\n", + "Status for continuing trial execution\n", + "\n", + "#### PAUSE\n", + "\n", + "Status for pausing trial execution\n", + "\n", + "#### STOP\n", + "\n", + "Status for stopping trial execution\n", + "\n", + "\n", + "---\n", + "sidebar_label: space\n", + "title: tune.space\n", + "---\n", + "\n", + "#### is\\_constant\n", + "\n", + "```python\n", + "def is_constant(space: Union[Dict, List]) -> bool\n", + "```\n", + "\n", + "Whether the search space is all constant.\n", + "\n", + "**Returns**:\n", + "\n", + " A bool of whether the search space is all constant.\n", + "\n", + "#### define\\_by\\_run\\_func\n", + "\n", + "```python\n", + "def define_by_run_func(trial,\n", + " space: Dict,\n", + " path: str = \"\") -> Optional[Dict[str, Any]]\n", + "```\n", + "\n", + "Define-by-run function to create the search space.\n", + "\n", + "**Returns**:\n", + "\n", + " A dict with constant values.\n", + "\n", + "#### unflatten\\_hierarchical\n", + "\n", + "```python\n", + "def unflatten_hierarchical(config: Dict, space: Dict) -> Tuple[Dict, Dict]\n", + "```\n", + "\n", + "Unflatten hierarchical config.\n", + "\n", + "#### add\\_cost\\_to\\_space\n", + "\n", + "```python\n", + "def add_cost_to_space(space: Dict, low_cost_point: Dict, choice_cost: Dict)\n", + "```\n", + "\n", + "Update the space in place by adding low_cost_point and choice_cost.\n", + "\n", + "**Returns**:\n", + "\n", + " A dict with constant values.\n", + "\n", + "#### normalize\n", + "\n", + "```python\n", + "def normalize(config: Dict,\n", + " space: Dict,\n", + " reference_config: Dict,\n", + " normalized_reference_config: Dict,\n", + " recursive: bool = False)\n", + "```\n", + "\n", + "Normalize config in space according to reference_config.\n", + "\n", + "Normalize each dimension in config to [0,1].\n", + "\n", + "#### indexof\n", + "\n", + "```python\n", + "def indexof(domain: Dict, config: Dict) -> int\n", + "```\n", + "\n", + "Find the index of config in domain.categories.\n", + "\n", + "#### complete\\_config\n", + "\n", + "```python\n", + "def complete_config(partial_config: Dict,\n", + " space: Dict,\n", + " flow2,\n", + " disturb: bool = False,\n", + " lower: Optional[Dict] = None,\n", + " upper: Optional[Dict] = None) -> Tuple[Dict, Dict]\n", + "```\n", + "\n", + "Complete partial config in space.\n", + "\n", + "**Returns**:\n", + "\n", + " config, space.\n", + "\n", + "\n", + "---\n", + "sidebar_label: search_thread\n", + "title: tune.searcher.search_thread\n", + "---\n", + "\n", + "## SearchThread Objects\n", + "\n", + "```python\n", + "class SearchThread()\n", + "```\n", + "\n", + "Class of global or local search thread.\n", + "\n", + "#### \\_\\_init\\_\\_\n", + "\n", + "```python\n", + "def __init__(mode: str = \"min\",\n", + " search_alg: Optional[Searcher] = None,\n", + " cost_attr: Optional[str] = TIME_TOTAL_S,\n", + " eps: Optional[float] = 1.0)\n", + "```\n", + "\n", + "When search_alg is omitted, use local search FLOW2.\n", + "\n", + "#### suggest\n", + "\n", + "```python\n", + "def suggest(trial_id: str) -> Optional[Dict]\n", + "```\n", + "\n", + "Use the suggest() of the underlying search algorithm.\n", + "\n", + "#### on\\_trial\\_complete\n", + "\n", + "```python\n", + "def on_trial_complete(trial_id: str,\n", + " result: Optional[Dict] = None,\n", + " error: bool = False)\n", + "```\n", + "\n", + "Update the statistics of the thread.\n", + "\n", + "#### reach\n", + "\n", + "```python\n", + "def reach(thread) -> bool\n", + "```\n", + "\n", + "Whether the incumbent can reach the incumbent of thread.\n", + "\n", + "#### can\\_suggest\n", + "\n", + "```python\n", + "@property\n", + "def can_suggest() -> bool\n", + "```\n", + "\n", + "Whether the thread can suggest new configs.\n", + "\n", + "\n", + "{\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/autogen/agentchat/contrib/math_user_proxy_agent\",\n", + " \"reference/autogen/agentchat/contrib/retrieve_assistant_agent\",\n", + " \"reference/autogen/agentchat/contrib/retrieve_user_proxy_agent\"\n", + " ],\n", + " \"label\": \"autogen.agentchat.contrib\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/autogen/agentchat/agent\",\n", + " \"reference/autogen/agentchat/assistant_agent\",\n", + " \"reference/autogen/agentchat/conversable_agent\",\n", + " \"reference/autogen/agentchat/groupchat\",\n", + " \"reference/autogen/agentchat/user_proxy_agent\"\n", + " ],\n", + " \"label\": \"autogen.agentchat\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/autogen/oai/completion\",\n", + " \"reference/autogen/oai/openai_utils\"\n", + " ],\n", + " \"label\": \"autogen.oai\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/autogen/code_utils\",\n", + " \"reference/autogen/math_utils\",\n", + " \"reference/autogen/retrieve_utils\"\n", + " ],\n", + " \"label\": \"autogen\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/nlp/huggingface/trainer\",\n", + " \"reference/automl/nlp/huggingface/training_args\",\n", + " \"reference/automl/nlp/huggingface/utils\"\n", + " ],\n", + " \"label\": \"automl.nlp.huggingface\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/automl/nlp/utils\"\n", + " ],\n", + " \"label\": \"automl.nlp\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/spark/metrics\",\n", + " \"reference/automl/spark/utils\"\n", + " ],\n", + " \"label\": \"automl.spark\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/task/task\",\n", + " \"reference/automl/task/time_series_task\"\n", + " ],\n", + " \"label\": \"automl.task\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/automl/time_series/sklearn\",\n", + " \"reference/automl/time_series/tft\",\n", + " \"reference/automl/time_series/ts_data\",\n", + " \"reference/automl/time_series/ts_model\"\n", + " ],\n", + " \"label\": \"automl.time_series\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/automl/automl\",\n", + " \"reference/automl/data\",\n", + " \"reference/automl/ml\",\n", + " \"reference/automl/model\",\n", + " \"reference/automl/state\"\n", + " ],\n", + " \"label\": \"automl\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/default/estimator\",\n", + " \"reference/default/greedy\",\n", + " \"reference/default/portfolio\",\n", + " \"reference/default/suggest\"\n", + " ],\n", + " \"label\": \"default\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/onlineml/autovw\",\n", + " \"reference/onlineml/trial\",\n", + " \"reference/onlineml/trial_runner\"\n", + " ],\n", + " \"label\": \"onlineml\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/scheduler/online_scheduler\",\n", + " \"reference/tune/scheduler/trial_scheduler\"\n", + " ],\n", + " \"label\": \"tune.scheduler\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/searcher/blendsearch\",\n", + " \"reference/tune/searcher/cfo_cat\",\n", + " \"reference/tune/searcher/flow2\",\n", + " \"reference/tune/searcher/online_searcher\",\n", + " \"reference/tune/searcher/search_thread\",\n", + " \"reference/tune/searcher/suggestion\",\n", + " \"reference/tune/searcher/variant_generator\"\n", + " ],\n", + " \"label\": \"tune.searcher\",\n", + " \"type\": \"category\"\n", + " },\n", + " {\n", + " \"items\": [\n", + " \"reference/tune/spark/utils\"\n", + " ],\n", + " \"label\": \"tune.spark\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/tune/analysis\",\n", + " \"reference/tune/sample\",\n", + " \"reference/tune/space\",\n", + " \"reference/tune/trial\",\n", + " \"reference/tune/trial_runner\",\n", + " \"reference/tune/tune\",\n", + " \"reference/tune/utils\"\n", + " ],\n", + " \"label\": \"tune\",\n", + " \"type\": \"category\"\n", + " },\n", + " \"reference/config\"\n", + " ],\n", + " \"label\": \"Reference\",\n", + " \"type\": \"category\"\n", + "}\n", + "---\n", + "sidebar_label: utils\n", + "title: tune.utils\n", + "---\n", + "\n", + "#### choice\n", + "\n", + "```python\n", + "def choice(categories: Sequence, order=None)\n", + "```\n", + "\n", + "Sample a categorical value.\n", + "Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from\n", + "``np.random.choice([1, 2])``\n", + "\n", + "**Arguments**:\n", + "\n", + "- `categories` _Sequence_ - Sequence of categories to sample from.\n", + "- `order` _bool_ - Whether the categories have an order. If None, will be decided autoamtically:\n", + " Numerical categories have an order, while string categories do not.\n", + "\n", + "\n", + "---\n", + "sidebar_label: trainer\n", + "title: automl.nlp.huggingface.trainer\n", + "---\n", + "\n", + "## TrainerForAuto Objects\n", + "\n", + "```python\n", + "class TrainerForAuto(Seq2SeqTrainer)\n", + "```\n", + "\n", + "#### evaluate\n", + "\n", + "```python\n", + "def evaluate(eval_dataset=None, ignore_keys=None, metric_key_prefix=\"eval\")\n", + "```\n", + "\n", + "Overriding transformers.Trainer.evaluate by saving metrics and checkpoint path.\n", + "\n", + "\n", + "---\n", + "sidebar_label: trial\n", + "title: onlineml.trial\n", + "---\n", + "\n", + "#### get\\_ns\\_feature\\_dim\\_from\\_vw\\_example\n", + "\n", + "```python\n", + "def get_ns_feature_dim_from_vw_example(vw_example) -> dict\n", + "```\n", + "\n", + "Get a dictionary of feature dimensionality for each namespace singleton.\n", + "\n", + "## OnlineResult Objects\n", + "\n", + "```python\n", + "class OnlineResult()\n", + "```\n", + "\n", + "Class for managing the result statistics of a trial.\n", + "\n", + "#### CB\\_COEF\n", + "\n", + "0.001 for mse\n", + "\n", + "#### \\_\\_init\\_\\_\n", + "\n", + "```python\n", + "def __init__(result_type_name: str,\n", + " cb_coef: Optional[float] = None,\n", + " init_loss: Optional[float] = 0.0,\n", + " init_cb: Optional[float] = 100.0,\n", + " mode: Optional[str] = \"min\",\n", + " sliding_window_size: Optional[int] = 100)\n", + "```\n", + "\n", + "Constructor.\n", + "\n", + "**Arguments**:\n", + "\n", + "- `result_type_name` - A String to specify the name of the result type.\n", + "- `cb_coef` - a string to specify the coefficient on the confidence bound.\n", + "- `init_loss` - a float to specify the inital loss.\n", + "- `init_cb` - a float to specify the intial confidence bound.\n", + "- `mode` - A string in ['min', 'max'] to specify the objective as\n", + " minimization or maximization.\n", + "- `sliding_window_size` - An int to specify the size of the sliding window\n", + " (for experimental purpose).\n", + "\n", + "#### update\\_result\n", + "\n", + "```python\n", + "def update_result(new_loss,\n", + " new_resource_used,\n", + " data_dimension,\n", + " bound_of_range=1.0,\n", + " new_observation_count=1.0)\n", + "```\n", + "\n", + "Update result statistics.\n", + "\n", + "## BaseOnlineTrial Objects\n", + "\n", + "```python\n", + "class BaseOnlineTrial(Trial)\n", + "```\n", + "\n", + "Class for the online trial.\n", + "\n", + "#### \\_\\_init\\_\\_\n", + "\n", + "```python\n", + "def __init__(config: dict,\n", + " min_resource_lease: float,\n", + " is_champion: Optional[bool] = False,\n", + " is_checked_under_current_champion: Optional[bool] = True,\n", + " custom_trial_name: Optional[str] = \"mae\",\n", + " trial_id: Optional[str] = None)\n", + "```\n", + "\n", + "Constructor.\n", + "\n", + "**Arguments**:\n", + "\n", + "- `config` - The configuration dictionary.\n", + "- `min_resource_lease` - A float specifying the minimum resource lease.\n", + "- `is_champion` - A bool variable indicating whether the trial is champion.\n", + "- `is_checked_under_current_champion` - A bool indicating whether the trial\n", + " has been used under the current champion.\n", + "- `custom_trial_name` - A string of a custom trial name.\n", + "- `trial_id` - A string for the trial id.\n", + "\n", + "#### set\\_resource\\_lease\n", + "\n", + "```python\n", + "def set_resource_lease(resource: float)\n", + "```\n", + "\n", + "Sets the resource lease accordingly.\n", + "\n", + "#### set\\_status\n", + "\n", + "```python\n", + "def set_status(status)\n", + "```\n", + "\n", + "Sets the status of the trial and record the start time.\n", + "\n", + "## VowpalWabbitTrial Objects\n", + "\n", + "```python\n", + "class VowpalWabbitTrial(BaseOnlineTrial)\n", + "```\n", + "\n", + "The class for Vowpal Wabbit online trials.\n", + "\n", + "#### \\_\\_init\\_\\_\n", + "\n", + "```python\n", + "def __init__(config: dict,\n", + " min_resource_lease: float,\n", + " metric: str = \"mae\",\n", + " is_champion: Optional[bool] = False,\n", + " is_checked_under_current_champion: Optional[bool] = True,\n", + " custom_trial_name: Optional[str] = \"vw_mae_clipped\",\n", + " trial_id: Optional[str] = None,\n", + " cb_coef: Optional[float] = None)\n", + "```\n", + "\n", + "Constructor.\n", + "\n", + "**Arguments**:\n", + "\n", + "- `config` _dict_ - the config of the trial (note that the config is a set\n", + " because the hyperparameters are).\n", + "- `min_resource_lease` _float_ - the minimum resource lease.\n", + "- `metric` _str_ - the loss metric.\n", + "- `is_champion` _bool_ - indicates whether the trial is the current champion or not.\n", + "- `is_checked_under_current_champion` _bool_ - indicates whether this trials has\n", + " been paused under the current champion.\n", + "- `trial_id` _str_ - id of the trial (if None, it will be generated in the constructor).\n", + "\n", + "#### train\\_eval\\_model\\_online\n", + "\n", + "```python\n", + "def train_eval_model_online(data_sample, y_pred)\n", + "```\n", + "\n", + "Train and evaluate model online.\n", + "\n", + "#### predict\n", + "\n", + "```python\n", + "def predict(x)\n", + "```\n", + "\n", + "Predict using the model.\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33massistant\u001b[0m (to ragproxyagent):\n", + "\n", + "The author of FLAML is Microsoft Corporation.\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + } + ], + "source": [ + "# reset the assistant. Always reset the assistant before starting a new conversation.\n", + "assistant.reset()\n", + "\n", + "qa_problem = \"Who is the author of FLAML?\"\n", + "ragproxyagent.initiate_chat(assistant, problem=qa_problem)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/test/agentchat/test_qdrant_retrievechat.py b/test/agentchat/test_qdrant_retrievechat.py new file mode 100644 index 000000000..9600b507e --- /dev/null +++ b/test/agentchat/test_qdrant_retrievechat.py @@ -0,0 +1,102 @@ +import os + +import pytest + +from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent +from autogen import ChatCompletion, config_list_from_json +from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST + +try: + from qdrant_client import QdrantClient + from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import ( + create_qdrant_from_dir, + QdrantRetrieveUserProxyAgent, + query_qdrant, + ) + import fastembed + + QDRANT_INSTALLED = True +except ImportError: + QDRANT_INSTALLED = False + +test_dir = os.path.join(os.path.dirname(__file__), "..", "test_files") + + +@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") +def test_retrievechat(): + try: + import openai + except ImportError: + return + + conversations = {} + ChatCompletion.start_logging(conversations) + + config_list = config_list_from_json( + OAI_CONFIG_LIST, + file_location=KEY_LOC, + filter_dict={ + "model": ["gpt-4", "gpt4", "gpt-4-32k", "gpt-4-32k-0314"], + }, + ) + + assistant = RetrieveAssistantAgent( + name="assistant", + system_message="You are a helpful assistant.", + llm_config={ + "request_timeout": 600, + "seed": 42, + "config_list": config_list, + }, + ) + + client = QdrantClient(":memory:") + ragproxyagent = QdrantRetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "client": client, + "docs_path": "./website/docs", + "chunk_token_size": 2000, + }, + ) + + assistant.reset() + + code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force cancel jobs if time limit is reached." + ragproxyagent.initiate_chat(assistant, problem=code_problem, silent=True) + print(conversations) + + +@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") +def test_qdrant_filter(): + client = QdrantClient(":memory:") + create_qdrant_from_dir(dir_path="./website/docs", client=client, collection_name="autogen-docs") + results = query_qdrant( + query_texts=["How can I use AutoGen UserProxyAgent and AssistantAgent to do code generation?"], + n_results=4, + client=client, + collection_name="autogen-docs", + # Return only documents with "AutoGen" in the string + search_string="AutoGen", + ) + assert len(results["ids"][0]) == 4 + + +@pytest.mark.skipif(not QDRANT_INSTALLED, reason="qdrant_client is not installed") +def test_qdrant_search(): + client = QdrantClient(":memory:") + create_qdrant_from_dir(test_dir, client=client) + + assert client.get_collection("all-my-documents") + + # Perform a semantic search without any filter + results = query_qdrant(["autogen"], client=client) + assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) + + +if __name__ == "__main__": + test_retrievechat() + test_qdrant_filter() + test_qdrant_search() diff --git a/website/docs/Examples/AutoGen-AgentChat.md b/website/docs/Examples/AutoGen-AgentChat.md index 9d29081ca..a9a813ae6 100644 --- a/website/docs/Examples/AutoGen-AgentChat.md +++ b/website/docs/Examples/AutoGen-AgentChat.md @@ -5,11 +5,13 @@ Please find documentation about this feature [here](/docs/Use-Cases/agent_chat). Links to notebook examples: + 1. **Code Generation, Execution, and Debugging** - Automated Task Solving with Code Generation, Execution & Debugging - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb) - Auto Code Generation, Execution, Debugging and Human Feedback - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_human_feedback.ipynb) - Automated Code Generation and Question Answering with Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_RetrieveChat.ipynb) + - Automated Code Generation and Question Answering with [Qdrant](https://qdrant.tech/) based Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb) 2. **Multi-Agent Collaboration (>3 Agents)**