# mirror of https://github.com/microsoft/autogen.git
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Union

import aiofiles
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.conditions import (
    ExternalTermination,
    HandoffTermination,
    MaxMessageTermination,
    SourceMatchTermination,
    StopMessageTermination,
    TextMentionTermination,
    TimeoutTermination,
    TokenUsageTermination,
)
from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat
from autogen_core.tools import FunctionTool
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient

from ..datamodel.types import (
    AgentConfig,
    AgentTypes,
    AssistantAgentConfig,
    AzureOpenAIModelConfig,
    CombinationTerminationConfig,
    ComponentConfig,
    ComponentConfigInput,
    ComponentTypes,
    MagenticOneTeamConfig,
    MaxMessageTerminationConfig,
    ModelConfig,
    ModelTypes,
    MultimodalWebSurferAgentConfig,
    OpenAIModelConfig,
    RoundRobinTeamConfig,
    SelectorTeamConfig,
    TeamConfig,
    TeamTypes,
    TerminationConfig,
    TerminationTypes,
    TextMentionTerminationConfig,
    ToolConfig,
    ToolTypes,
    UserProxyAgentConfig,
)
from ..utils.utils import Version

logger = logging.getLogger(__name__)

TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat]
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer, UserProxyAgent, FileSurfer, MagenticOneCoderAgent]
ModelComponent = Union[OpenAIChatCompletionClient, AzureOpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool]  # Will grow with more tool types
TerminationComponent = Union[
    MaxMessageTermination,
    StopMessageTermination,
    TextMentionTermination,
    TimeoutTermination,
    ExternalTermination,
    TokenUsageTermination,
    HandoffTermination,
    SourceMatchTermination,
]

Component = Union[TeamComponent, AgentComponent, ModelComponent, ToolComponent, TerminationComponent]

ReturnType = Literal["object", "dict", "config"]

DEFAULT_SELECTOR_PROMPT = """You are in a role play game. The following roles are available:
{roles}.
Read the following conversation. Then select the next role from {participants} to play. Only return the role.

{history}

Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""

CONFIG_RETURN_TYPES = Literal["object", "dict", "config"]


class ComponentFactory:
    """Creates and manages agent components with versioned configuration loading"""

    SUPPORTED_VERSIONS = {
        ComponentTypes.TEAM: ["1.0.0"],
        ComponentTypes.AGENT: ["1.0.0"],
        ComponentTypes.MODEL: ["1.0.0"],
        ComponentTypes.TOOL: ["1.0.0"],
        ComponentTypes.TERMINATION: ["1.0.0"],
    }

    def __init__(self):
        self._model_cache: Dict[str, ModelComponent] = {}
        self._tool_cache: Dict[str, FunctionTool] = {}
        self._last_cache_clear = datetime.now()

    async def load(
        self, component: ComponentConfigInput, input_func: Optional[Callable] = None, return_type: ReturnType = "object"
    ) -> Union[Component, dict, ComponentConfig]:
        """
        Universal loader for any component type

        Args:
            component: Component configuration (file path, dict, or ComponentConfig)
            input_func: Optional callable for user input handling
            return_type: Type of return value ('object', 'dict', or 'config')

        Returns:
            Component instance, config dict, or ComponentConfig based on return_type
        """
        try:
            # Load and validate config
            if isinstance(component, (str, Path)):
                component_dict = await self._load_from_file(component)
                config = self._dict_to_config(component_dict)
            elif isinstance(component, dict):
                config = self._dict_to_config(component)
            else:
                config = component

            # Validate version
            if not self._is_version_supported(config.component_type, config.version):
                raise ValueError(
                    f"Unsupported version {config.version} for "
                    f"component type {config.component_type}. "
                    f"Supported versions: {self.SUPPORTED_VERSIONS[config.component_type]}"
                )

            # Return early if dict or config requested
            if return_type == "dict":
                return config.model_dump()
            elif return_type == "config":
                return config

            # Otherwise create and return component instance
            handlers = {
                ComponentTypes.TEAM: lambda c: self.load_team(c, input_func),
                ComponentTypes.AGENT: lambda c: self.load_agent(c, input_func),
                ComponentTypes.MODEL: self.load_model,
                ComponentTypes.TOOL: self.load_tool,
                ComponentTypes.TERMINATION: self.load_termination,
            }

            handler = handlers.get(config.component_type)
            if not handler:
                raise ValueError(f"Unknown component type: {config.component_type}")

            return await handler(config)

        except Exception as e:
            logger.error(f"Failed to load component: {str(e)}")
            raise

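    # Example (illustrative sketch; "team.json" is a hypothetical config file): the
    # loader accepts a file path, a raw dict, or a parsed ComponentConfig, and
    # `return_type` controls what comes back:
    #
    #   factory = ComponentFactory()
    #   team = await factory.load("team.json")                            # live object
    #   team_cfg = await factory.load("team.json", return_type="config")  # parsed config
    #   team_dict = await factory.load("team.json", return_type="dict")   # plain dict
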
    async def load_directory(
        self, directory: Union[str, Path], return_type: ReturnType = "object"
    ) -> List[Union[Component, dict, ComponentConfig]]:
        """
        Import all component configurations from a directory.
        """
        components = []
        try:
            directory = Path(directory)
            # Enumerate files with Path.glob and keep only supported config formats
            for path in directory.glob("*"):
                if path.suffix.lower() in (".json", ".yaml", ".yml"):
                    try:
                        component = await self.load(path, return_type=return_type)
                        components.append(component)
                    except Exception as e:
                        logger.info(f"Failed to load component: {str(e)}, {path}")

            return components
        except Exception as e:
            logger.info(f"Failed to load directory: {str(e)}")
            return components

    def _dict_to_config(self, config_dict: dict) -> ComponentConfig:
        """Convert dictionary to appropriate config type based on component_type and type discriminator"""
        if "component_type" not in config_dict:
            raise ValueError("component_type is required in configuration")

        component_type = ComponentTypes(config_dict["component_type"])

        # Define mapping structure
        type_mappings = {
            ComponentTypes.MODEL: {
                "discriminator": "model_type",
                ModelTypes.OPENAI.value: OpenAIModelConfig,
                ModelTypes.AZUREOPENAI.value: AzureOpenAIModelConfig,
            },
            ComponentTypes.AGENT: {
                "discriminator": "agent_type",
                AgentTypes.ASSISTANT.value: AssistantAgentConfig,
                AgentTypes.USERPROXY.value: UserProxyAgentConfig,
                AgentTypes.MULTIMODAL_WEBSURFER.value: MultimodalWebSurferAgentConfig,
            },
            ComponentTypes.TEAM: {
                "discriminator": "team_type",
                TeamTypes.ROUND_ROBIN.value: RoundRobinTeamConfig,
                TeamTypes.SELECTOR.value: SelectorTeamConfig,
                TeamTypes.MAGENTIC_ONE.value: MagenticOneTeamConfig,
            },
            ComponentTypes.TOOL: ToolConfig,
            ComponentTypes.TERMINATION: {
                "discriminator": "termination_type",
                TerminationTypes.MAX_MESSAGES.value: MaxMessageTerminationConfig,
                TerminationTypes.TEXT_MENTION.value: TextMentionTerminationConfig,
                TerminationTypes.COMBINATION.value: CombinationTerminationConfig,
            },
        }

        mapping = type_mappings.get(component_type)
        if not mapping:
            raise ValueError(f"Unknown component type: {component_type}")

        # Handle simple cases (no discriminator)
        if isinstance(mapping, type):
            return mapping(**config_dict)

        # Get discriminator field value
        discriminator = mapping["discriminator"]
        if discriminator not in config_dict:
            raise ValueError(f"Missing {discriminator} in configuration")

        type_value = config_dict[discriminator]
        config_class = mapping.get(type_value)

        if not config_class:
            raise ValueError(f"Unknown {discriminator}: {type_value}")

        return config_class(**config_dict)

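    # Example (illustrative sketch): a dict routed through the mapping above.
    # "component_type" selects the mapping entry and the discriminator ("model_type"
    # here) selects the concrete config class; literal values are written via the
    # enums since their string values are defined in ..datamodel.types. Fields beyond
    # the ones this module reads are omitted.
    #
    #   {
    #       "component_type": ComponentTypes.MODEL.value,
    #       "model_type": ModelTypes.OPENAI.value,  # -> OpenAIModelConfig
    #       "model": "gpt-4o-mini",                 # hypothetical model name
    #       "version": "1.0.0",
    #   }
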
    async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
        """Create termination condition instance from configuration."""
        try:
            if config.termination_type == TerminationTypes.COMBINATION:
                if not config.conditions or len(config.conditions) < 2:
                    raise ValueError("Combination termination requires at least 2 conditions")
                if not config.operator:
                    raise ValueError("Combination termination requires an operator (and/or)")

                # Load first two conditions
                conditions = [await self.load_termination(cond) for cond in config.conditions[:2]]
                result = conditions[0] & conditions[1] if config.operator == "and" else conditions[0] | conditions[1]

                # Process remaining conditions if any
                for condition in config.conditions[2:]:
                    next_condition = await self.load_termination(condition)
                    result = result & next_condition if config.operator == "and" else result | next_condition

                return result

            elif config.termination_type == TerminationTypes.MAX_MESSAGES:
                if config.max_messages is None:
                    raise ValueError("max_messages parameter required for MaxMessageTermination")
                return MaxMessageTermination(max_messages=config.max_messages)

            elif config.termination_type == TerminationTypes.STOP_MESSAGE:
                return StopMessageTermination()

            elif config.termination_type == TerminationTypes.TEXT_MENTION:
                if not config.text:
                    raise ValueError("text parameter required for TextMentionTermination")
                return TextMentionTermination(text=config.text)

            else:
                raise ValueError(f"Unsupported termination type: {config.termination_type}")

        except Exception as e:
            logger.error(f"Failed to create termination condition: {str(e)}")
            raise ValueError(f"Termination condition creation failed: {str(e)}") from e

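    # Example (illustrative sketch): a combination termination config as consumed
    # above. "conditions" holds nested termination configs and "operator" is "and"
    # or "or"; the conditions are folded left-to-right with & or |. Values are
    # hypothetical and nested entries are shown abbreviated.
    #
    #   {
    #       "component_type": ComponentTypes.TERMINATION.value,
    #       "termination_type": TerminationTypes.COMBINATION.value,
    #       "operator": "or",
    #       "version": "1.0.0",
    #       "conditions": [
    #           {"termination_type": TerminationTypes.MAX_MESSAGES.value, "max_messages": 10},
    #           {"termination_type": TerminationTypes.TEXT_MENTION.value, "text": "TERMINATE"},
    #       ],
    #   }
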
    async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = None) -> TeamComponent:
        """Create team instance from configuration."""
        try:
            # Load participants (agents) with input_func
            participants = []
            for participant in config.participants:
                agent = await self.load(participant, input_func=input_func)
                participants.append(agent)

            # Load termination condition if specified
            termination = None
            if config.termination_condition:
                termination = await self.load(config.termination_condition)

            # Create team based on type
            if config.team_type == TeamTypes.ROUND_ROBIN:
                return RoundRobinGroupChat(participants=participants, termination_condition=termination)
            elif config.team_type == TeamTypes.SELECTOR:
                model_client = await self.load(config.model_client)
                if not model_client:
                    raise ValueError("SelectorGroupChat requires a model_client")
                selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT
                return SelectorGroupChat(
                    participants=participants,
                    model_client=model_client,
                    termination_condition=termination,
                    selector_prompt=selector_prompt,
                )
            elif config.team_type == TeamTypes.MAGENTIC_ONE:
                model_client = await self.load(config.model_client)
                if not model_client:
                    raise ValueError("MagenticOneGroupChat requires a model_client")
                return MagenticOneGroupChat(
                    participants=participants,
                    model_client=model_client,
                    termination_condition=termination,
                    max_turns=config.max_turns if config.max_turns is not None else 20,
                )
            else:
                raise ValueError(f"Unsupported team type: {config.team_type}")

        except Exception as e:
            logger.error(f"Failed to create team {config.name}: {str(e)}")
            raise ValueError(f"Team creation failed: {str(e)}") from e

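    # Example (illustrative sketch): the shape of a selector-team config as read above.
    # "participants" holds nested agent configs, and "model_client" /
    # "termination_condition" are themselves component configs resolved via self.load().
    # Angle-bracket placeholders stand for those nested configs; values are hypothetical.
    #
    #   {
    #       "component_type": ComponentTypes.TEAM.value,
    #       "team_type": TeamTypes.SELECTOR.value,
    #       "name": "example_team",
    #       "version": "1.0.0",
    #       "participants": [<agent config>, <agent config>],
    #       "model_client": <model config>,
    #       "termination_condition": <termination config>,
    #       "selector_prompt": None,  # falls back to DEFAULT_SELECTOR_PROMPT
    #   }
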
    async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] = None) -> AgentComponent:
        """Create agent instance from configuration."""

        model_client = None
        system_message = None
        tools = []
        if hasattr(config, "system_message") and config.system_message:
            system_message = config.system_message
        if hasattr(config, "model_client") and config.model_client:
            model_client = await self.load(config.model_client)
        if hasattr(config, "tools") and config.tools:
            for tool_config in config.tools:
                tool = await self.load(tool_config)
                tools.append(tool)

        try:
            if config.agent_type == AgentTypes.USERPROXY:
                return UserProxyAgent(
                    name=config.name,
                    description=config.description or "A human user",
                    input_func=input_func,  # Pass through to UserProxyAgent
                )
            elif config.agent_type == AgentTypes.ASSISTANT:
                system_message = config.system_message if config.system_message else "You are a helpful assistant"

                return AssistantAgent(
                    name=config.name,
                    description=config.description or "A helpful assistant",
                    model_client=model_client,
                    tools=tools,
                    system_message=system_message,
                )
            elif config.agent_type == AgentTypes.MULTIMODAL_WEBSURFER:
                return MultimodalWebSurfer(
                    name=config.name,
                    model_client=model_client,
                    headless=config.headless if config.headless is not None else True,
                    debug_dir=config.logs_dir if config.logs_dir is not None else None,
                    downloads_folder=config.logs_dir if config.logs_dir is not None else None,
                    to_save_screenshots=config.to_save_screenshots if config.to_save_screenshots is not None else False,
                    use_ocr=config.use_ocr if config.use_ocr is not None else False,
                    animate_actions=config.animate_actions if config.animate_actions is not None else False,
                )
            elif config.agent_type == AgentTypes.FILE_SURFER:
                return FileSurfer(
                    name=config.name,
                    model_client=model_client,
                )
            elif config.agent_type == AgentTypes.MAGENTIC_ONE_CODER:
                return MagenticOneCoderAgent(
                    name=config.name,
                    model_client=model_client,
                )
            else:
                raise ValueError(f"Unsupported agent type: {config.agent_type}")

        except Exception as e:
            logger.error(f"Failed to create agent {config.name}: {str(e)}")
            raise ValueError(f"Agent creation failed: {str(e)}") from e

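    # Example (illustrative sketch): an assistant agent config as read above. Only
    # fields this method touches are shown; "model_client" and "tools" are nested
    # component configs resolved via self.load(). Values are hypothetical.
    #
    #   {
    #       "component_type": ComponentTypes.AGENT.value,
    #       "agent_type": AgentTypes.ASSISTANT.value,
    #       "name": "assistant",
    #       "version": "1.0.0",
    #       "system_message": "You are a helpful assistant",
    #       "model_client": <model config>,
    #       "tools": [<tool config>],
    #   }
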
    async def load_model(self, config: ModelConfig) -> ModelComponent:
        """Create model instance from configuration."""
        try:
            # Check cache first
            cache_key = str(config.model_dump())
            if cache_key in self._model_cache:
                logger.debug(f"Using cached model for {config.model}")
                return self._model_cache[cache_key]

            if config.model_type == ModelTypes.OPENAI:
                args = {
                    "model": config.model,
                    "api_key": config.api_key,
                    "base_url": config.base_url,
                }

                if hasattr(config, "model_capabilities") and config.model_capabilities is not None:
                    args["model_capabilities"] = config.model_capabilities

                model = OpenAIChatCompletionClient(**args)
                self._model_cache[cache_key] = model
                return model
            elif config.model_type == ModelTypes.AZUREOPENAI:
                model = AzureOpenAIChatCompletionClient(
                    azure_deployment=config.azure_deployment,
                    model=config.model,
                    api_version=config.api_version,
                    azure_endpoint=config.azure_endpoint,
                    api_key=config.api_key,
                )
                self._model_cache[cache_key] = model
                return model
            else:
                raise ValueError(f"Unsupported model type: {config.model_type}")

        except Exception as e:
            logger.error(f"Failed to create model {config.model}: {str(e)}")
            raise ValueError(f"Model creation failed: {str(e)}") from e

    async def load_tool(self, config: ToolConfig) -> ToolComponent:
        """Create tool instance from configuration."""
        try:
            # Validate required fields
            if not all([config.name, config.description, config.content, config.tool_type]):
                raise ValueError("Tool configuration missing required fields")

            # Check cache first
            cache_key = str(config.model_dump())
            if cache_key in self._tool_cache:
                logger.debug(f"Using cached tool '{config.name}'")
                return self._tool_cache[cache_key]

            if config.tool_type == ToolTypes.PYTHON_FUNCTION:
                tool = FunctionTool(
                    name=config.name, description=config.description, func=self._func_from_string(config.content)
                )
                self._tool_cache[cache_key] = tool
                return tool
            else:
                raise ValueError(f"Unsupported tool type: {config.tool_type}")

        except Exception as e:
            logger.error(f"Failed to create tool '{config.name}': {str(e)}")
            raise

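    # Example (illustrative sketch): a Python-function tool config. "content" holds
    # source code handed to _func_from_string below, which exec()s it and returns the
    # first non-class callable found in the resulting namespace. Values are hypothetical.
    #
    #   {
    #       "component_type": ComponentTypes.TOOL.value,
    #       "tool_type": ToolTypes.PYTHON_FUNCTION.value,
    #       "name": "add_numbers",
    #       "description": "Add two integers",
    #       "version": "1.0.0",
    #       "content": "def add_numbers(a: int, b: int) -> int:\n    return a + b",
    #   }
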
    async def _load_from_file(self, path: Union[str, Path]) -> dict:
        """Load configuration from JSON or YAML file."""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Config file not found: {path}")

        try:
            async with aiofiles.open(path) as f:
                content = await f.read()
                if path.suffix == ".json":
                    return json.loads(content)
                elif path.suffix in (".yml", ".yaml"):
                    return yaml.safe_load(content)
                else:
                    raise ValueError(f"Unsupported file format: {path.suffix}")
        except Exception as e:
            raise ValueError(f"Failed to load file {path}: {str(e)}") from e

    def _func_from_string(self, content: str) -> Callable:
        """Convert function string to callable."""
        try:
            namespace = {}
            exec(content, namespace)
            for item in namespace.values():
                if callable(item) and not isinstance(item, type):
                    return item
            raise ValueError("No function found in provided code")
        except Exception as e:
            raise ValueError(f"Failed to create function: {str(e)}") from e

    def _is_version_supported(self, component_type: ComponentTypes, ver: str) -> bool:
        """Check if version is supported for component type."""
        try:
            version = Version(ver)
            supported = [Version(v) for v in self.SUPPORTED_VERSIONS[component_type]]
            return any(version == v for v in supported)
        except ValueError:
            return False

    async def cleanup(self) -> None:
        """Cleanup resources and clear caches."""
        for model in self._model_cache.values():
            if hasattr(model, "cleanup"):
                await model.cleanup()

        for tool in self._tool_cache.values():
            if hasattr(tool, "cleanup"):
                await tool.cleanup()

        self._model_cache.clear()
        self._tool_cache.clear()
        self._last_cache_clear = datetime.now()
        logger.info("Cleared all component caches")
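

if __name__ == "__main__":
    # Illustrative usage sketch (hypothetical values; run via `python -m ...` so the
    # relative imports above resolve): build a termination condition from a plain dict
    # using only the fields this module itself reads.
    import asyncio

    async def _demo() -> None:
        factory = ComponentFactory()
        termination = await factory.load(
            {
                "component_type": ComponentTypes.TERMINATION.value,
                "termination_type": TerminationTypes.MAX_MESSAGES.value,
                "version": "1.0.0",
                "max_messages": 5,
            }
        )
        print(type(termination).__name__)  # expected: MaxMessageTermination
        await factory.cleanup()

    asyncio.run(_demo())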