Expose more Task-Centric Memory parameters (#6246)

<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

- Exposes a few optional memory controller parameters for more detailed
control and evaluation.
- Fixes a couple of formatting issues in the documentation.

## Related issue number

None

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
This commit is contained in:
Ricky Loynd 2025-04-08 15:13:34 -07:00 committed by GitHub
parent ac315ef3ce
commit b3f59057fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 74 additions and 36 deletions

View File

@ -184,7 +184,7 @@ class Prompter:
return topic_list
async def generalize_task(self, task_description: str) -> str:
async def generalize_task(self, task_description: str, revise: bool | None = True) -> str:
"""
Attempts to rewrite a task description in a more general form.
"""
@ -198,29 +198,31 @@ class Prompter:
user_message.append(task_description)
self._clear_history()
await self.call_model(
generalized_task = await self.call_model(
summary="Ask the model to rephrase the task in a list of important points",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
]
await self.call_model(
summary="Ask the model to identify irrelevant points",
system_message_content=sys_message,
user_content=user_message,
)
if revise:
user_message = [
"Do you see any parts of this list that are irrelevant to actually solving the task? If so, explain which items are irrelevant."
]
await self.call_model(
summary="Ask the model to identify irrelevant points",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
]
generalized_task = await self.call_model(
summary="Ask the model to make a final list of general terms",
system_message_content=sys_message,
user_content=user_message,
)
user_message = [
"Revise your original list to include only the most general terms, those that are critical to solving the task, removing any themes or descriptions that are not essential to the solution. Your final list may be shorter, but do not leave out any part of the task that is needed for solving the task. Do not add any additional commentary either before or after the list."
]
generalized_task = await self.call_model(
summary="Ask the model to make a final list of general terms",
system_message_content=sys_message,
user_content=user_message,
)
return generalized_task
async def validate_insight(self, insight: str, task_description: str) -> bool:

View File

@ -16,6 +16,11 @@ from .utils.page_logger import PageLogger
# Following the nested-config pattern, this TypedDict minimizes code changes by encapsulating
# the settings that change frequently, as when loading many settings from a single YAML file.
class MemoryControllerConfig(TypedDict, total=False):
generalize_task: bool
revise_generalized_task: bool
generate_topics: bool
validate_memos: bool
max_memos_to_retrieve: int
max_train_trials: int
max_test_trials: int
MemoryBank: "MemoryBankConfig"
@ -33,6 +38,11 @@ class MemoryController:
task_assignment_callback: An optional callback used to assign a task to any agent managed by the caller.
config: An optional dict that can be used to override the following values:
- generalize_task: Whether to rewrite tasks in more general terms.
- revise_generalized_task: Whether to critique then rewrite the generalized task.
- generate_topics: Whether to base retrieval directly on tasks, or on topics extracted from tasks.
- validate_memos: Whether to apply a final validation stage to retrieved memos.
- max_memos_to_retrieve: The maximum number of memos to return from retrieve_relevant_memos().
- max_train_trials: The maximum number of learning iterations to attempt when training on a task.
- max_test_trials: The total number of attempts made when testing for failure on a task.
- MemoryBank: A config dict passed to MemoryBank.
@ -91,10 +101,20 @@ class MemoryController:
self.logger.enter_function()
# Apply default settings and any config overrides.
self.generalize_task = True
self.revise_generalized_task = True
self.generate_topics = True
self.validate_memos = True
self.max_memos_to_retrieve = 10
self.max_train_trials = 10
self.max_test_trials = 3
memory_bank_config = None
if config is not None:
self.generalize_task = config.get("generalize_task", self.generalize_task)
self.revise_generalized_task = config.get("revise_generalized_task", self.revise_generalized_task)
self.generate_topics = config.get("generate_topics", self.generate_topics)
self.validate_memos = config.get("validate_memos", self.validate_memos)
self.max_memos_to_retrieve = config.get("max_memos_to_retrieve", self.max_memos_to_retrieve)
self.max_train_trials = config.get("max_train_trials", self.max_train_trials)
self.max_test_trials = config.get("max_test_trials", self.max_test_trials)
memory_bank_config = config.get("MemoryBank", memory_bank_config)
@ -178,8 +198,10 @@ class MemoryController:
if task is not None:
self.logger.info("\nGIVEN TASK:")
self.logger.info(task)
# Generalize the task.
generalized_task = await self.prompter.generalize_task(task)
if self.generalize_task:
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
else:
generalized_task = task
self.logger.info("\nGIVEN INSIGHT:")
self.logger.info(insight)
@ -196,7 +218,10 @@ class MemoryController:
text_to_index = task
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
topics = await self.prompter.find_index_topics(text_to_index)
if self.generate_topics:
topics = await self.prompter.find_index_topics(text_to_index)
else:
topics = [text_to_index]
self.logger.info("\n".join(topics))
self.logger.info("")
@ -218,7 +243,10 @@ class MemoryController:
self.logger.info(solution)
# Get a list of topics from the task.
topics = await self.prompter.find_index_topics(task.strip())
if self.generate_topics:
topics = await self.prompter.find_index_topics(task.strip())
else:
topics = [task.strip()]
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
self.logger.info("\n".join(topics))
self.logger.info("")
@ -238,8 +266,14 @@ class MemoryController:
self.logger.info(task)
# Get a list of topics from the generalized task.
generalized_task = await self.prompter.generalize_task(task)
task_topics = await self.prompter.find_index_topics(generalized_task)
if self.generalize_task:
generalized_task = await self.prompter.generalize_task(task, revise=self.revise_generalized_task)
else:
generalized_task = task
if self.generate_topics:
task_topics = await self.prompter.find_index_topics(generalized_task)
else:
task_topics = [generalized_task]
self.logger.info("\nTOPICS EXTRACTED FROM TASK:")
self.logger.info("\n".join(task_topics))
self.logger.info("")
@ -250,7 +284,9 @@ class MemoryController:
# Apply a final validation stage to keep only the memos that the LLM concludes are sufficiently relevant.
validated_memos: List[Memo] = []
for memo in memo_list:
if await self.prompter.validate_insight(memo.insight, task):
if len(validated_memos) >= self.max_memos_to_retrieve:
break
if (not self.validate_memos) or await self.prompter.validate_insight(memo.insight, task):
validated_memos.append(memo)
self.logger.info("\n{} VALIDATED MEMOS".format(len(validated_memos)))

View File

@ -41,10 +41,9 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
create calls) or a "stream" (a list of streamed outputs for create_stream calls).
ReplayChatCompletionClient and ChatCompletionCache do similar things, but with significant differences:
- ReplayChatCompletionClient replays pre-defined responses in a specified order
without recording anything or checking the messages sent to the client.
- ChatCompletionCache caches responses and replays them for messages that have been seen before,
regardless of order, and calls the base client for any uncached messages.
- ReplayChatCompletionClient replays pre-defined responses in a specified order without recording anything or checking the messages sent to the client.
- ChatCompletionCache caches responses and replays them for messages that have been seen before, regardless of order, and calls the base client for any uncached messages.
"""
def __init__(

View File

@ -14,10 +14,11 @@ class Teachability(Memory):
Gives an AssistantAgent the ability to learn quickly from user teachings, hints, and advice.
Steps for usage:
1. Instantiate MemoryController.
2. Instantiate Teachability, passing the memory controller as a parameter.
3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
4. Use the AssistantAgent as usual, such as for chatting with the user.
1. Instantiate MemoryController.
2. Instantiate Teachability, passing the memory controller as a parameter.
3. Instantiate an AssistantAgent, passing the teachability instance (wrapped in a list) as the memory parameter.
4. Use the AssistantAgent as usual, such as for chatting with the user.
"""
def __init__(self, memory_controller: "MemoryController", name: str | None = None) -> None:

View File

@ -41,7 +41,7 @@ or else modify `utils/client.py` as appropriate for the model you choose.
## Running the Samples
The following samples are listed in order of increasing complexity.
Execute the corresponding commands from this (autogen_ext/task_centric_memory) directory.
Execute the corresponding commands from the `python/samples/task_centric_memory` directory.
### Making AssistantAgent Teachable

View File

@ -15,10 +15,10 @@ client:
Apprentice:
name_of_agent_or_team: AssistantAgent # AssistantAgent or MagenticOneGroupChat
disable_prefix_caching: 1 # If true, prepends a small random string to the context, to decorrelate repeated runs.
TaskCentricMemoryController:
MemoryController:
max_train_trials: 10
max_test_trials: 3
TaskCentricMemoryBank:
MemoryBank:
path: ./memory_bank/self_teaching
relevance_conversion_threshold: 1.7
n_results: 25