mirror of https://github.com/microsoft/autogen.git
Expose more Task-Centric Memory parameters (#6246)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? - Exposes a few optional memory controller parameters for more detailed control and evaluation. - Fixes a couple formatting issues in the documentation. ## Related issue number None ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed.
This commit is contained in:
parent
ac315ef3ce
commit
b3f59057fa
|
@ -184,7 +184,7 @@ class Prompter:
|
|||
|
||||
return topic_list
|
||||
|
||||
async def generalize_task(self, task_description: str) -> str:
|
||||
async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
|
||||
"""
|
||||
Attempts to rewrite a task description in a more general form.
|
||||
"""
|
||||
|
@ -198,29 +198,31 @@ class Prompter:
|
|||
user_message.append(task_description)
|
||||
|
||||
self._clear_history()
|
||||
await self.call_model(
|
||||
generalized_task = await self.call_model(
|
||||
summary="Ask the model to rephrase the task in a list of important points",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
user_message = [
|
||||
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
|
||||
]
|
||||
await self.call_model(
|
||||
summary="Ask the model to identify irrelevant points",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
if revise:
|
||||
user_message = [
|
||||
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
|
||||
]
|
||||
await self.call_model(
|
||||
summary="Ask the model to identify irrelevant points",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
user_message = [
|
||||
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
|
||||
]
|
||||
generalized_task = await self.call_model(
|
||||
summary="Ask the model to make a final list of general terms",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
|
||||
user_message = [
|
||||
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
|
||||
]
|
||||
generalized_task = await self.call_model(
|
||||
summary="Ask the model to make a final list of general terms",
|
||||
system_message_content=sys_message,
|
||||
user_content=user_message,
|
||||
)
|
||||
return generalized_task
|
||||
|
||||
async def validate_insight(self, insight: str, task_description: str) -> bool:
|
||||
|
|
|
@ -16,6 +16,11 @@ from .utils.page_logger import PageLogger
|
|||
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
|
||||
# the settings that change frequently, as when loading many settings from a single YAML file.
|
||||
class MemoryControllerConfig(TypedDict, total=False):
|
||||
generalize_task: bool
|
||||
revise_generalized_task: bool
|
||||
generate_topics: bool
|
||||
validate_memos: bool
|
||||
max_memos_to_retrieve: int
|
||||
max_train_trials: int
|
||||
max_test_trials: int
|
||||
MemoryBank: "MemoryBankConfig"
|
||||
|
@ -33,6 +38,11 @@ class MemoryController:
|
|||
task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
|
||||
config: An optional dict that can be used to override the following values:
|
||||
|
||||
- generalize_task: Whether to rewrite tasks in more general terms.
|
||||
- revise_generalized_task: Whether to critique then rewrite the generalized task.
|
||||
- generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
|
||||
- validate_memos: Whether to apply a final validation stage to retrieved memos.
|
||||
- max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
|
||||
- max_train_trials: The maximum number of learning iterations to attempt when training on a task.
|
||||
- max_test_trials: The total number of attempts made when testing for failure on a task.
|
||||
- MemoryBank: A config dict passed to MemoryBank.
|
||||
|
@ -91,10 +101,20 @@ class MemoryController:
|
|||
self.logger.enter_function()
|
||||
|
||||
# Apply default settings and any config overrides.
|
||||
self.generalize_task = True
|
||||
self.revise_generalized_task = True
|
||||
self.generate_topics = True
|
||||
self.validate_memos = True
|
||||
self.max_memos_to_retrieve = 10
|
||||
self.max_train_trials = 10
|
||||
self.max_test_trials = 3
|
||||
memory_bank_config = None
|
||||
if config is not None:
|
||||
self.generalize_task = config.get("generalize_task", self.generalize_task)
|
||||
self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task)
|
||||
self.generate_topics = config.get("generate_topics", self.generate_topics)
|
||||
self.validate_memos = config.get("validate_memos", self.validate_memos)
|
||||
self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve)
|
||||
self.max_train_trials = config.get("max_train_trials", self.max_train_trials)
|
||||
self.max_test_trials = config.get("max_test_trials", self.max_test_trials)
|
||||
memory_bank_config = config.get("MemoryBank", memory_bank_config)
|
||||
|
@ -178,8 +198,10 @@ class MemoryController:
|
|||
if task is not None:
|
||||
self.logger.info("\nGIVEN TASK:")
|
||||
self.logger.info(task)
|
||||
# Generalize the task.
|
||||
generalized_task = await self.prompter.generalize_task(task)
|
||||
if self.generalize_task:
|
||||
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
|
||||
else:
|
||||
generalized_task = task
|
||||
|
||||
self.logger.info("\nGIVEN INSIGHT:")
|
||||
self.logger.info(insight)
|
||||
|
@ -196,7 +218,10 @@ class MemoryController:
|
|||
text_to_index = task
|
||||
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
|
||||
|
||||
topics = await self.prompter.find_index_topics(text_to_index)
|
||||
if self.generate_topics:
|
||||
topics = await self.prompter.find_index_topics(text_to_index)
|
||||
else:
|
||||
topics = [text_to_index]
|
||||
self.logger.info("\n".join(topics))
|
||||
self.logger.info("")
|
||||
|
||||
|
@ -218,7 +243,10 @@ class MemoryController:
|
|||
self.logger.info(solution)
|
||||
|
||||
# Get a list of topics from the task.
|
||||
topics = await self.prompter.find_index_topics(task.strip())
|
||||
if self.generate_topics:
|
||||
topics = await self.prompter.find_index_topics(task.strip())
|
||||
else:
|
||||
topics = [task.strip()]
|
||||
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
|
||||
self.logger.info("\n".join(topics))
|
||||
self.logger.info("")
|
||||
|
@ -238,8 +266,14 @@ class MemoryController:
|
|||
self.logger.info(task)
|
||||
|
||||
# Get a list of topics from the generalized task.
|
||||
generalized_task = await self.prompter.generalize_task(task)
|
||||
task_topics = await self.prompter.find_index_topics(generalized_task)
|
||||
if self.generalize_task:
|
||||
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
|
||||
else:
|
||||
generalized_task = task
|
||||
if self.generate_topics:
|
||||
task_topics = await self.prompter.find_index_topics(generalized_task)
|
||||
else:
|
||||
task_topics = [generalized_task]
|
||||
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
|
||||
self.logger.info("\n".join(task_topics))
|
||||
self.logger.info("")
|
||||
|
@ -250,7 +284,9 @@ class MemoryController:
|
|||
# Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant.
|
||||
validated_memos: List[Memo] = []
|
||||
for memo in memo_list:
|
||||
if await self.prompter.validate_insight(memo.insight, task):
|
||||
if len(validated_memos) >= self.max_memos_to_retrieve:
|
||||
break
|
||||
if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task):
|
||||
validated_memos.append(memo)
|
||||
|
||||
self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))
|
||||
|
|
|
@ -41,10 +41,9 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
|
|||
create calls) or a "stream" (a list of streamed outputs for create_stream calls).
|
||||
|
||||
ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:
|
||||
- ReplayChatCompletionClient replays pre-defined responses in a specified order
|
||||
without recording anything or checking the messages sent to the client.
|
||||
- ChatCompletionCache caches responses and replays them for messages that have been seen before,
|
||||
regardless of order, and calls the base client for any uncached messages.
|
||||
|
||||
- ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client.
|
||||
- ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -14,10 +14,11 @@ class Teachability(Memory):
|
|||
Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.
|
||||
|
||||
Steps for usage:
|
||||
1. Instantiate MemoryController.
|
||||
2. Instantiate Teachability, passing the memory controller as a parameter.
|
||||
3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
|
||||
4. Use the AssistantAgent as usual, such as for chatting with the user.
|
||||
|
||||
1. Instantiate MemoryController.
|
||||
2. Instantiate Teachability, passing the memory controller as a parameter.
|
||||
3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
|
||||
4. Use the AssistantAgent as usual, such as for chatting with the user.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:
|
||||
|
|
|
@ -41,7 +41,7 @@ or else modify `utils/client.py` as appropriate for the model you choose.
|
|||
## Running the Samples
|
||||
|
||||
The following samples are listed in order of increasing complexity.
|
||||
Execute the corresponding commands from this (autogen_ext/task_centric_memory) directory.
|
||||
Execute the corresponding commands from the `python/samples/task_centric_memory` directory.
|
||||
|
||||
|
||||
### Making AssistantAgent Teachable
|
||||
|
|
|
@ -15,10 +15,10 @@ client:
|
|||
Apprentice:
|
||||
name_of_agent_or_team: AssistantAgent # AssistantAgent or MagenticOneGroupChat
|
||||
disable_prefix_caching: 1 # If true, prepends a small random string to the context, to decorrelate repeated runs.
|
||||
TaskCentricMemoryController:
|
||||
MemoryController:
|
||||
max_train_trials: 10
|
||||
max_test_trials: 3
|
||||
TaskCentricMemoryBank:
|
||||
MemoryBank:
|
||||
path: ./memory_bank/self_teaching
|
||||
relevance_conversion_threshold: 1.7
|
||||
n_results: 25
|
||||
|
|
Loading…
Reference in New Issue