mirror of https://github.com/microsoft/autogen.git
Add AOAI Support in AGS (#4718)
* add oai support, improve component config typing, minor updates to docs, update ags tests * faq updates * update faq, add model_capabilities * update faq
This commit is contained in:
parent
e2df4e24db
commit
6a4a11042c
|
@ -13,8 +13,54 @@ A: You can specify the directory where files are stored by setting the `--appdir
|
|||
|
||||
## Q: Can I use other models with AutoGen Studio?
|
||||
|
||||
Yes. AutoGen standardizes on the openai model api format, and you can use any api server that offers an openai compliant endpoint. In the AutoGen Studio UI, each agent has an `model_client` field where you can input your model endpoint details including `model`, `api key`, `base url`, `model type` and `api version`. For Azure OpenAI models, you can find these details in the Azure portal. Note that for Azure OpenAI, the `model name` is the deployment id or engine, and the `model type` is "azure".
|
||||
For other OSS models, we recommend using a server such as vllm, LMStudio, Ollama, to instantiate an openai compliant endpoint.
|
||||
Yes. AutoGen standardizes on the openai model api format, and you can use any api server that offers an openai compliant endpoint.
|
||||
|
||||
AutoGen Studio is based on declaritive specifications which applies to models as well. Agents can include a model_client field which specifies the model endpoint details including `model`, `api_key`, `base_url`, `model type`.
|
||||
|
||||
An example of the openai model client is shown below:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o-mini",
|
||||
"model_type": "OpenAIChatCompletionClient",
|
||||
"api_key": "your-api-key"
|
||||
}
|
||||
```
|
||||
|
||||
An example of the azure openai model client is shown below:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4o-mini",
|
||||
"model_type": "AzureOpenAIChatCompletionClient",
|
||||
"azure_deployment": "gpt-4o-mini",
|
||||
"api_version": "2024-02-15-preview",
|
||||
"azure_endpoint": "https://your-endpoint.openai.azure.com/",
|
||||
"api_key": "your-api-key",
|
||||
"component_type": "model"
|
||||
}
|
||||
```
|
||||
|
||||
Have a local model server like Ollama, vLLM or LMStudio that provide an OpenAI compliant endpoint? You can use that as well.
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
|
||||
"model_type": "OpenAIChatCompletionClient",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"api_version": "1.0",
|
||||
"component_type": "model",
|
||||
"model_capabilities": {
|
||||
"vision": false,
|
||||
"function_calling": false,
|
||||
"json_output": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
```{caution}
|
||||
It is important that you add the `model_capabilities` field to the model client specification for custom models. This is used by the framework instantiate and use the model correctly.
|
||||
```
|
||||
|
||||
## Q: The server starts but I can't access the UI
|
||||
|
||||
|
|
|
@ -13,27 +13,27 @@ The expected usage behavior is that developers use the provided Team Builder int
|
|||
|
||||
## Building an Agent Team
|
||||
|
||||
AutoGen Studio is tied very closely with all of the component abstractions provided by AutoGen AgentChat. This includes - {py:class}`~autogen_agentchat.teams`, {py:class}`~autogen_agentchat.agents`, {py:class}`~autogen_core.models`, {py:class}`~autogen_core.tools`, {py:class}`~autogen_agentchat.conditions`.
|
||||
AutoGen Studio is tied very closely with all of the component abstractions provided by AutoGen AgentChat. This includes - {py:class}`~autogen_agentchat.teams`, {py:class}`~autogen_agentchat.agents`, {py:class}`~autogen_core.models`, {py:class}`~autogen_core.tools`, termination {py:class}`~autogen_agentchat.conditions`.
|
||||
|
||||
Users can define these components in the Team Builder interface either via a declarative specification or by dragging and dropping components from a component library.
|
||||
|
||||
## Testing an Agent Team
|
||||
## Interactively Running Teams
|
||||
|
||||
AutoGen Studio Playground allows users to interactively test teams on tasks and review resulting artifacts (such as images, code, and documents).
|
||||
AutoGen Studio Playground allows users to interactively test teams on tasks and review resulting artifacts (such as images, code, and text).
|
||||
|
||||
Users can also review the “inner monologue” of team as they address tasks, and view profiling information such as costs associated with the run (such as number of turns, number of tokens etc.), and agent actions (such as whether tools were called and the outcomes of code execution).
|
||||
|
||||
## Importing and Reusing Team Configurations
|
||||
|
||||
AutoGen Studio provides a Gallery view where users can import components from 3rd party community sources. This allows users to reuse and share team configurations with others.
|
||||
AutoGen Studio provides a Gallery view which provides a built-in default gallery. A Gallery is simply is a collection of components - teams, agents, models tools etc. Furthermore, users can import components from 3rd party community sources either by providing a URL to a JSON Gallery spec or pasting in the gallery JSON. This allows users to reuse and share team configurations with others.
|
||||
|
||||
- Gallery -> New Gallery -> Import
|
||||
- Set as default gallery
|
||||
- Reuse components in Team Builder
|
||||
- Set as default gallery (in side bar, by clicking pin icon)
|
||||
- Reuse components in Team Builder. Team Builder -> Sidebar -> From Gallery
|
||||
|
||||
### Using AutoGen Studio Teams in a Python Application
|
||||
|
||||
An exported team can be easily integrated into any Python application using the `TeamManager` class with just two lines of code. Underneath, the `TeamManager` rehydrates the workflow specification into AutoGen agents that are subsequently used to address tasks.
|
||||
An exported team can be easily integrated into any Python application using the `TeamManager` class with just two lines of code. Underneath, the `TeamManager` rehydrates the team specification into AutoGen AgentChat agents that are subsequently used to address tasks.
|
||||
|
||||
```python
|
||||
|
||||
|
@ -44,12 +44,14 @@ result_stream = tm.run(task="What is the weather in New York?", team_config="te
|
|||
|
||||
```
|
||||
|
||||
To export a team configuration, click on the export button in the Team Builder interface. This will generate a JSON file that can be used to rehydrate the team in a Python application.
|
||||
|
||||
<!-- ### Deploying AutoGen Studio Teams as APIs
|
||||
|
||||
The team can be launched as an API endpoint from the command line using the autogenstudio commandline tool.
|
||||
|
||||
```bash
|
||||
autogenstudio serve --workflow=workflow.json --port=5000
|
||||
autogenstudio serve --team=team.json --port=5000
|
||||
```
|
||||
|
||||
Similarly, the workflow launch command above can be wrapped into a Dockerfile that can be deployed on cloud services like Azure Container Apps or Azure Web Apps. -->
|
||||
Similarly, the team launch command above can be wrapped into a Dockerfile that can be deployed on cloud services like Azure Container Apps or Azure Web Apps. -->
|
||||
|
|
|
@ -18,26 +18,37 @@ from autogen_agentchat.conditions import (
|
|||
TokenUsageTermination,
|
||||
)
|
||||
from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from autogen_core.tools import FunctionTool
|
||||
from autogen_ext.agents.file_surfer import FileSurfer
|
||||
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
|
||||
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
|
||||
from autogen_ext.models import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient, OpenAIChatCompletionClient
|
||||
|
||||
from ..datamodel.types import (
|
||||
AgentConfig,
|
||||
AgentTypes,
|
||||
AssistantAgentConfig,
|
||||
AzureOpenAIModelConfig,
|
||||
CombinationTerminationConfig,
|
||||
ComponentConfig,
|
||||
ComponentConfigInput,
|
||||
ComponentTypes,
|
||||
MagenticOneTeamConfig,
|
||||
MaxMessageTerminationConfig,
|
||||
ModelConfig,
|
||||
ModelTypes,
|
||||
MultimodalWebSurferAgentConfig,
|
||||
OpenAIModelConfig,
|
||||
RoundRobinTeamConfig,
|
||||
SelectorTeamConfig,
|
||||
TeamConfig,
|
||||
TeamTypes,
|
||||
TerminationConfig,
|
||||
TerminationTypes,
|
||||
TextMentionTerminationConfig,
|
||||
ToolConfig,
|
||||
ToolTypes,
|
||||
UserProxyAgentConfig,
|
||||
)
|
||||
from ..utils.utils import Version
|
||||
|
||||
|
@ -45,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat]
|
||||
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer, UserProxyAgent, FileSurfer, MagenticOneCoderAgent]
|
||||
ModelComponent = Union[OpenAIChatCompletionClient]
|
||||
ModelComponent = Union[OpenAIChatCompletionClient, AzureOpenAIChatCompletionClient]
|
||||
ToolComponent = Union[FunctionTool] # Will grow with more tool types
|
||||
TerminationComponent = Union[
|
||||
MaxMessageTermination,
|
||||
|
@ -87,7 +98,7 @@ class ComponentFactory:
|
|||
}
|
||||
|
||||
def __init__(self):
|
||||
self._model_cache: Dict[str, OpenAIChatCompletionClient] = {}
|
||||
self._model_cache: Dict[str, ModelComponent] = {}
|
||||
self._tool_cache: Dict[str, FunctionTool] = {}
|
||||
self._last_cache_clear = datetime.now()
|
||||
|
||||
|
@ -172,23 +183,58 @@ class ComponentFactory:
|
|||
return components
|
||||
|
||||
def _dict_to_config(self, config_dict: dict) -> ComponentConfig:
|
||||
"""Convert dictionary to appropriate config type based on component_type"""
|
||||
"""Convert dictionary to appropriate config type based on component_type and type discriminator"""
|
||||
if "component_type" not in config_dict:
|
||||
raise ValueError("component_type is required in configuration")
|
||||
|
||||
config_types = {
|
||||
ComponentTypes.TEAM: TeamConfig,
|
||||
ComponentTypes.AGENT: AgentConfig,
|
||||
ComponentTypes.MODEL: ModelConfig,
|
||||
component_type = ComponentTypes(config_dict["component_type"])
|
||||
|
||||
# Define mapping structure
|
||||
type_mappings = {
|
||||
ComponentTypes.MODEL: {
|
||||
"discriminator": "model_type",
|
||||
ModelTypes.OPENAI.value: OpenAIModelConfig,
|
||||
ModelTypes.AZUREOPENAI.value: AzureOpenAIModelConfig,
|
||||
},
|
||||
ComponentTypes.AGENT: {
|
||||
"discriminator": "agent_type",
|
||||
AgentTypes.ASSISTANT.value: AssistantAgentConfig,
|
||||
AgentTypes.USERPROXY.value: UserProxyAgentConfig,
|
||||
AgentTypes.MULTIMODAL_WEBSURFER.value: MultimodalWebSurferAgentConfig,
|
||||
},
|
||||
ComponentTypes.TEAM: {
|
||||
"discriminator": "team_type",
|
||||
TeamTypes.ROUND_ROBIN.value: RoundRobinTeamConfig,
|
||||
TeamTypes.SELECTOR.value: SelectorTeamConfig,
|
||||
TeamTypes.MAGENTIC_ONE.value: MagenticOneTeamConfig,
|
||||
},
|
||||
ComponentTypes.TOOL: ToolConfig,
|
||||
ComponentTypes.TERMINATION: TerminationConfig, # Add mapping for termination
|
||||
ComponentTypes.TERMINATION: {
|
||||
"discriminator": "termination_type",
|
||||
TerminationTypes.MAX_MESSAGES.value: MaxMessageTerminationConfig,
|
||||
TerminationTypes.TEXT_MENTION.value: TextMentionTerminationConfig,
|
||||
TerminationTypes.COMBINATION.value: CombinationTerminationConfig,
|
||||
},
|
||||
}
|
||||
|
||||
component_type = ComponentTypes(config_dict["component_type"])
|
||||
config_class = config_types.get(component_type)
|
||||
mapping = type_mappings.get(component_type)
|
||||
if not mapping:
|
||||
raise ValueError(f"Unknown component type: {component_type}")
|
||||
|
||||
# Handle simple cases (no discriminator)
|
||||
if isinstance(mapping, type):
|
||||
return mapping(**config_dict)
|
||||
|
||||
# Get discriminator field value
|
||||
discriminator = mapping["discriminator"]
|
||||
if discriminator not in config_dict:
|
||||
raise ValueError(f"Missing {discriminator} in configuration")
|
||||
|
||||
type_value = config_dict[discriminator]
|
||||
config_class = mapping.get(type_value)
|
||||
|
||||
if not config_class:
|
||||
raise ValueError(f"Unknown component type: {component_type}")
|
||||
raise ValueError(f"Unknown {discriminator}: {type_value}")
|
||||
|
||||
return config_class(**config_dict)
|
||||
|
||||
|
@ -241,11 +287,6 @@ class ComponentFactory:
|
|||
agent = await self.load(participant, input_func=input_func)
|
||||
participants.append(agent)
|
||||
|
||||
# Load model client if specified
|
||||
model_client = None
|
||||
if config.model_client:
|
||||
model_client = await self.load(config.model_client)
|
||||
|
||||
# Load termination condition if specified
|
||||
termination = None
|
||||
if config.termination_condition:
|
||||
|
@ -255,6 +296,7 @@ class ComponentFactory:
|
|||
if config.team_type == TeamTypes.ROUND_ROBIN:
|
||||
return RoundRobinGroupChat(participants=participants, termination_condition=termination)
|
||||
elif config.team_type == TeamTypes.SELECTOR:
|
||||
model_client = await self.load(config.model_client)
|
||||
if not model_client:
|
||||
raise ValueError("SelectorGroupChat requires a model_client")
|
||||
selector_prompt = config.selector_prompt if config.selector_prompt else DEFAULT_SELECTOR_PROMPT
|
||||
|
@ -265,6 +307,7 @@ class ComponentFactory:
|
|||
selector_prompt=selector_prompt,
|
||||
)
|
||||
elif config.team_type == TeamTypes.MAGENTIC_ONE:
|
||||
model_client = await self.load(config.model_client)
|
||||
if not model_client:
|
||||
raise ValueError("MagenticOneGroupChat requires a model_client")
|
||||
return MagenticOneGroupChat(
|
||||
|
@ -282,14 +325,15 @@ class ComponentFactory:
|
|||
|
||||
async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] = None) -> AgentComponent:
|
||||
"""Create agent instance from configuration."""
|
||||
|
||||
system_message = config.system_message if config.system_message else "You are a helpful assistant"
|
||||
|
||||
try:
|
||||
# Load model client if specified
|
||||
model_client = None
|
||||
if config.model_client:
|
||||
model_client = await self.load(config.model_client)
|
||||
|
||||
system_message = config.system_message if config.system_message else "You are a helpful assistant"
|
||||
|
||||
# Load tools if specified
|
||||
tools = []
|
||||
if config.tools:
|
||||
|
@ -304,6 +348,8 @@ class ComponentFactory:
|
|||
input_func=input_func, # Pass through to UserProxyAgent
|
||||
)
|
||||
elif config.agent_type == AgentTypes.ASSISTANT:
|
||||
system_message = config.system_message if config.system_message else "You are a helpful assistant"
|
||||
|
||||
return AssistantAgent(
|
||||
name=config.name,
|
||||
description=config.description or "A helpful assistant",
|
||||
|
@ -349,7 +395,26 @@ class ComponentFactory:
|
|||
return self._model_cache[cache_key]
|
||||
|
||||
if config.model_type == ModelTypes.OPENAI:
|
||||
model = OpenAIChatCompletionClient(model=config.model, api_key=config.api_key, base_url=config.base_url)
|
||||
args = {
|
||||
"model": config.model,
|
||||
"api_key": config.api_key,
|
||||
"base_url": config.base_url,
|
||||
}
|
||||
|
||||
if hasattr(config, "model_capabilities") and config.model_capabilities is not None:
|
||||
args["model_capabilities"] = config.model_capabilities
|
||||
|
||||
model = OpenAIChatCompletionClient(**args)
|
||||
self._model_cache[cache_key] = model
|
||||
return model
|
||||
elif config.model_type == ModelTypes.AZUREOPENAI:
|
||||
model = AzureOpenAIChatCompletionClient(
|
||||
azure_deployment=config.azure_deployment,
|
||||
model=config.model,
|
||||
api_version=config.api_version,
|
||||
azure_endpoint=config.azure_endpoint,
|
||||
api_key=config.api_key,
|
||||
)
|
||||
self._model_cache[cache_key] = model
|
||||
return model
|
||||
else:
|
||||
|
|
|
@ -118,7 +118,7 @@ class Tool(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[ToolConfig, dict] = Field(default_factory=ToolConfig, sa_column=Column(JSON))
|
||||
config: Union[ToolConfig, dict] = Field(sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="tools", link_model=AgentToolLink)
|
||||
|
||||
|
||||
|
@ -135,7 +135,7 @@ class Model(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[ModelConfig, dict] = Field(default_factory=ModelConfig, sa_column=Column(JSON))
|
||||
config: Union[ModelConfig, dict] = Field(sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="models", link_model=AgentModelLink)
|
||||
|
||||
|
||||
|
@ -152,7 +152,7 @@ class Team(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[TeamConfig, dict] = Field(default_factory=TeamConfig, sa_column=Column(JSON))
|
||||
config: Union[TeamConfig, dict] = Field(sa_column=Column(JSON))
|
||||
agents: List["Agent"] = Relationship(back_populates="teams", link_model=TeamAgentLink)
|
||||
|
||||
|
||||
|
@ -169,7 +169,7 @@ class Agent(SQLModel, table=True):
|
|||
) # pylint: disable=not-callable
|
||||
user_id: Optional[str] = None
|
||||
version: Optional[str] = "0.0.1"
|
||||
config: Union[AgentConfig, dict] = Field(default_factory=AgentConfig, sa_column=Column(JSON))
|
||||
config: Union[AgentConfig, dict] = Field(sa_column=Column(JSON))
|
||||
tools: List[Tool] = Relationship(back_populates="agents", link_model=AgentToolLink)
|
||||
models: List[Model] = Relationship(back_populates="agents", link_model=AgentModelLink)
|
||||
teams: List[Team] = Relationship(back_populates="agents", link_model=TeamAgentLink)
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from autogen_agentchat.base import TaskResult
|
||||
from pydantic import BaseModel
|
||||
from autogen_core.models import ModelCapabilities
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelTypes(str, Enum):
|
||||
OPENAI = "OpenAIChatCompletionClient"
|
||||
AZUREOPENAI = "AzureOpenAIChatCompletionClient"
|
||||
|
||||
|
||||
class ToolTypes(str, Enum):
|
||||
|
@ -56,12 +58,30 @@ class MessageConfig(BaseModel):
|
|||
message_type: Optional[str] = "text"
|
||||
|
||||
|
||||
class ModelConfig(BaseConfig):
|
||||
class BaseModelConfig(BaseConfig):
|
||||
model: str
|
||||
model_type: ModelTypes
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
component_type: ComponentTypes = ComponentTypes.MODEL
|
||||
model_capabilities: Optional[ModelCapabilities] = None
|
||||
|
||||
|
||||
class OpenAIModelConfig(BaseModelConfig):
|
||||
model_type: ModelTypes = ModelTypes.OPENAI
|
||||
|
||||
|
||||
class AzureOpenAIModelConfig(BaseModelConfig):
|
||||
azure_deployment: str
|
||||
model: str
|
||||
api_version: str
|
||||
azure_endpoint: str
|
||||
azure_ad_token_provider: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
model_type: ModelTypes = ModelTypes.AZUREOPENAI
|
||||
|
||||
|
||||
ModelConfig = OpenAIModelConfig | AzureOpenAIModelConfig
|
||||
|
||||
|
||||
class ToolConfig(BaseConfig):
|
||||
|
@ -72,43 +92,100 @@ class ToolConfig(BaseConfig):
|
|||
component_type: ComponentTypes = ComponentTypes.TOOL
|
||||
|
||||
|
||||
class AgentConfig(BaseConfig):
|
||||
class BaseAgentConfig(BaseConfig):
|
||||
name: str
|
||||
agent_type: AgentTypes
|
||||
system_message: Optional[str] = None
|
||||
model_client: Optional[ModelConfig] = None
|
||||
tools: Optional[List[ToolConfig]] = None
|
||||
description: Optional[str] = None
|
||||
component_type: ComponentTypes = ComponentTypes.AGENT
|
||||
headless: Optional[bool] = None
|
||||
logs_dir: Optional[str] = None
|
||||
to_save_screenshots: Optional[bool] = None
|
||||
use_ocr: Optional[bool] = None
|
||||
animate_actions: Optional[bool] = None
|
||||
|
||||
|
||||
class TerminationConfig(BaseConfig):
|
||||
class AssistantAgentConfig(BaseAgentConfig):
|
||||
agent_type: AgentTypes = AgentTypes.ASSISTANT
|
||||
model_client: ModelConfig
|
||||
tools: Optional[List[ToolConfig]] = None
|
||||
system_message: Optional[str] = None
|
||||
|
||||
|
||||
class UserProxyAgentConfig(BaseAgentConfig):
|
||||
agent_type: AgentTypes = AgentTypes.USERPROXY
|
||||
|
||||
|
||||
class MultimodalWebSurferAgentConfig(BaseAgentConfig):
|
||||
agent_type: AgentTypes = AgentTypes.MULTIMODAL_WEBSURFER
|
||||
model_client: ModelConfig
|
||||
headless: bool = True
|
||||
logs_dir: str = "logs"
|
||||
to_save_screenshots: bool = False
|
||||
use_ocr: bool = False
|
||||
animate_actions: bool = False
|
||||
tools: Optional[List[ToolConfig]] = None
|
||||
|
||||
|
||||
AgentConfig = AssistantAgentConfig | UserProxyAgentConfig | MultimodalWebSurferAgentConfig
|
||||
|
||||
|
||||
class BaseTerminationConfig(BaseConfig):
|
||||
termination_type: TerminationTypes
|
||||
# Fields for basic terminations
|
||||
max_messages: Optional[int] = None
|
||||
text: Optional[str] = None
|
||||
# Fields for combinations
|
||||
operator: Optional[Literal["and", "or"]] = None
|
||||
conditions: Optional[List["TerminationConfig"]] = None
|
||||
component_type: ComponentTypes = ComponentTypes.TERMINATION
|
||||
|
||||
|
||||
class TeamConfig(BaseConfig):
|
||||
class MaxMessageTerminationConfig(BaseTerminationConfig):
|
||||
termination_type: TerminationTypes = TerminationTypes.MAX_MESSAGES
|
||||
max_messages: int
|
||||
|
||||
|
||||
class TextMentionTerminationConfig(BaseTerminationConfig):
|
||||
termination_type: TerminationTypes = TerminationTypes.TEXT_MENTION
|
||||
text: str
|
||||
|
||||
|
||||
class StopMessageTerminationConfig(BaseTerminationConfig):
|
||||
termination_type: TerminationTypes = TerminationTypes.STOP_MESSAGE
|
||||
|
||||
|
||||
class CombinationTerminationConfig(BaseTerminationConfig):
|
||||
termination_type: TerminationTypes = TerminationTypes.COMBINATION
|
||||
operator: str
|
||||
conditions: List["TerminationConfig"]
|
||||
|
||||
|
||||
TerminationConfig = (
|
||||
MaxMessageTerminationConfig
|
||||
| TextMentionTerminationConfig
|
||||
| CombinationTerminationConfig
|
||||
| StopMessageTerminationConfig
|
||||
)
|
||||
|
||||
|
||||
class BaseTeamConfig(BaseConfig):
|
||||
name: str
|
||||
participants: List[AgentConfig]
|
||||
team_type: TeamTypes
|
||||
model_client: Optional[ModelConfig] = None
|
||||
selector_prompt: Optional[str] = None
|
||||
termination_condition: Optional[TerminationConfig] = None
|
||||
component_type: ComponentTypes = ComponentTypes.TEAM
|
||||
max_turns: Optional[int] = None
|
||||
|
||||
|
||||
class RoundRobinTeamConfig(BaseTeamConfig):
|
||||
team_type: TeamTypes = TeamTypes.ROUND_ROBIN
|
||||
|
||||
|
||||
class SelectorTeamConfig(BaseTeamConfig):
|
||||
team_type: TeamTypes = TeamTypes.SELECTOR
|
||||
selector_prompt: Optional[str] = None
|
||||
model_client: ModelConfig
|
||||
|
||||
|
||||
class MagenticOneTeamConfig(BaseTeamConfig):
|
||||
team_type: TeamTypes = TeamTypes.MAGENTIC_ONE
|
||||
model_client: ModelConfig
|
||||
max_stalls: int = 3
|
||||
final_answer_prompt: Optional[str] = None
|
||||
|
||||
|
||||
TeamConfig = RoundRobinTeamConfig | SelectorTeamConfig | MagenticOneTeamConfig
|
||||
|
||||
|
||||
class TeamResult(BaseModel):
|
||||
task_result: TaskResult
|
||||
usage: str
|
||||
|
|
|
@ -276,7 +276,8 @@ class WebSocketManager:
|
|||
if run_id not in self._closed_connections:
|
||||
error_result = TeamResult(
|
||||
task_result=TaskResult(
|
||||
messages=[TextMessage(source="system", content=str(error))], stop_reason="error"
|
||||
messages=[TextMessage(source="system", content=str(error))],
|
||||
stop_reason="An error occurred while processing this run",
|
||||
),
|
||||
usage="",
|
||||
duration=0,
|
||||
|
|
|
@ -117,6 +117,9 @@ const RunView: React.FC<RunViewProps> = ({
|
|||
}
|
||||
};
|
||||
|
||||
const lastResultMessage = run.team_result?.task_result.messages.slice(-1)[0];
|
||||
const lastMessage = run.messages.slice(-1)[0];
|
||||
|
||||
return (
|
||||
<div className="space-y-6 mr-2 ">
|
||||
{/* Run Header */}
|
||||
|
@ -193,14 +196,23 @@ const RunView: React.FC<RunViewProps> = ({
|
|||
Stop reason: {run.team_result?.task_result?.stop_reason}
|
||||
</div>
|
||||
|
||||
<TruncatableText
|
||||
key={"_" + run.id}
|
||||
textThreshold={700}
|
||||
content={
|
||||
run.messages[run.messages.length - 1]?.config?.content + ""
|
||||
}
|
||||
className="break-all"
|
||||
/>
|
||||
{lastMessage ? (
|
||||
<TruncatableText
|
||||
key={"_" + run.id}
|
||||
textThreshold={700}
|
||||
content={
|
||||
run.messages[run.messages.length - 1]?.config?.content +
|
||||
""
|
||||
}
|
||||
className="break-all"
|
||||
/>
|
||||
) : (
|
||||
<>
|
||||
{lastResultMessage && (
|
||||
<RenderMessage message={lastResultMessage} />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
|
|
@ -16,9 +16,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_jcgtAVlBvTFzVpPxKX88Xsa4', arguments='{\"city\":\"New York\"}', name='get_weather')], type='ToolCallMessage'), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_jcgtAVlBvTFzVpPxKX88Xsa4')], type='ToolCallResultMessage'), TextMessage(source='writing_agent', models_usage=None, content='The weather in New York is 73 degrees and Sunny.', type='TextMessage'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=103, completion_tokens=14), content='The current weather in New York is 73 degrees and sunny.', type='TextMessage')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=5.103050947189331\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from autogenstudio.teammanager import TeamManager\n",
|
||||
"\n",
|
||||
|
@ -29,9 +37,22 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"source='user' models_usage=None content='What is the weather in New York?' type='TextMessage'\n",
|
||||
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15) content=[FunctionCall(id='call_EwdwWogp5jDKdB7t9WGCNjZW', arguments='{\"city\":\"New York\"}', name='get_weather')] type='ToolCallMessage'\n",
|
||||
"source='writing_agent' models_usage=None content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_EwdwWogp5jDKdB7t9WGCNjZW')] type='ToolCallResultMessage'\n",
|
||||
"source='writing_agent' models_usage=None content='The weather in New York is 73 degrees and Sunny.' type='TextMessage'\n",
|
||||
"source='writing_agent' models_usage=RequestUsage(prompt_tokens=103, completion_tokens=14) content='The weather in New York is currently 73 degrees and sunny.' type='TextMessage'\n",
|
||||
"task_result=TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the weather in New York?', type='TextMessage'), ToolCallMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=65, completion_tokens=15), content=[FunctionCall(id='call_EwdwWogp5jDKdB7t9WGCNjZW', arguments='{\"city\":\"New York\"}', name='get_weather')], type='ToolCallMessage'), ToolCallResultMessage(source='writing_agent', models_usage=None, content=[FunctionExecutionResult(content='The weather in New York is 73 degrees and Sunny.', call_id='call_EwdwWogp5jDKdB7t9WGCNjZW')], type='ToolCallResultMessage'), TextMessage(source='writing_agent', models_usage=None, content='The weather in New York is 73 degrees and Sunny.', type='TextMessage'), TextMessage(source='writing_agent', models_usage=RequestUsage(prompt_tokens=103, completion_tokens=14), content='The weather in New York is currently 73 degrees and sunny.', type='TextMessage')], stop_reason='Maximum number of messages 5 reached, current message count: 5') usage='' duration=1.284574270248413\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result_stream = wm.run_stream(task=\"What is the weather in New York?\", team_config=\"team.json\")\n",
|
||||
"async for response in result_stream:\n",
|
||||
|
@ -49,7 +70,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -78,25 +99,26 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"from sqlmodel import Session, text, select\n",
|
||||
"from autogenstudio.datamodel import Model, ModelConfig, ModelTypes, Team, TeamConfig, TeamTypes, Agent, AgentConfig, AgentTypes, Tool, ToolConfig, LinkTypes,ToolTypes\n",
|
||||
"from autogenstudio.datamodel.types import ModelTypes, TeamTypes, AgentTypes, ToolConfig, ToolTypes, OpenAIModelConfig, RoundRobinTeamConfig, MaxMessageTerminationConfig, AssistantAgentConfig, TerminationTypes\n",
|
||||
"\n",
|
||||
"user_id = \"guestuser@gmail.com\"\n",
|
||||
"from autogenstudio.datamodel import ModelConfig, Model, TeamConfig, Team, Tool, Agent, AgentConfig, TerminationConfig, TerminationTypes, ModelTypes, TeamTypes, AgentTypes, ToolConfig, LinkTypes, TerminationTypes\n",
|
||||
"from autogenstudio.datamodel.db import Model, Team, Agent, Tool,LinkTypes\n",
|
||||
"\n",
|
||||
"gpt4_model = Model(user_id=user_id, config= ModelConfig(model=\"gpt-4o-2024-08-06\", model_type=ModelTypes.OPENAI).model_dump() )\n",
|
||||
"user_id = \"guestuser@gmail.com\" \n",
|
||||
"\n",
|
||||
"gpt4_model = Model(user_id=user_id, config= OpenAIModelConfig(model=\"gpt-4o-2024-08-06\", model_type=ModelTypes.OPENAI).model_dump() )\n",
|
||||
"\n",
|
||||
"weather_tool = Tool(user_id=user_id, config=ToolConfig(name=\"get_weather\", description=\"Get the weather for a city\", content=\"async def get_weather(city: str) -> str:\\n return f\\\"The weather in {city} is 73 degrees and Sunny.\\\"\",tool_type=ToolTypes.PYTHON_FUNCTION).model_dump() )\n",
|
||||
"\n",
|
||||
"adding_tool = Tool(user_id=user_id, config=ToolConfig(name=\"add\", description=\"Add two numbers\", content=\"async def add(a: int, b: int) -> int:\\n return a + b\", tool_type=ToolTypes.PYTHON_FUNCTION).model_dump() )\n",
|
||||
"\n",
|
||||
"writing_agent = Agent(user_id=user_id,\n",
|
||||
" config=AgentConfig(\n",
|
||||
" config=AssistantAgentConfig(\n",
|
||||
" name=\"writing_agent\",\n",
|
||||
" tools=[weather_tool.config],\n",
|
||||
" agent_type=AgentTypes.ASSISTANT,\n",
|
||||
|
@ -104,10 +126,10 @@
|
|||
" ).model_dump()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"team = Team(user_id=user_id, config=TeamConfig(\n",
|
||||
"team = Team(user_id=user_id, config=RoundRobinTeamConfig(\n",
|
||||
" name=\"weather_team\",\n",
|
||||
" participants=[writing_agent.config],\n",
|
||||
" termination_condition=TerminationConfig(termination_type=TerminationTypes.MAX_MESSAGES, max_messages=5).model_dump(),\n",
|
||||
" termination_condition=MaxMessageTerminationConfig(termination_type=TerminationTypes.MAX_MESSAGES, max_messages=5).model_dump(),\n",
|
||||
" team_type=TeamTypes.ROUND_ROBIN\n",
|
||||
" ).model_dump()\n",
|
||||
")\n",
|
||||
|
@ -155,7 +177,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -166,9 +188,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"message='Team Created Successfully' status=True data={'id': 4, 'updated_at': datetime.datetime(2024, 12, 15, 15, 52, 21, 674916), 'version': '0.0.1', 'created_at': datetime.datetime(2024, 12, 15, 15, 52, 21, 674910), 'user_id': 'user_id', 'config': {'version': '1.0.0', 'component_type': 'team', 'name': 'weather_team', 'participants': [{'version': '1.0.0', 'component_type': 'agent', 'name': 'writing_agent', 'agent_type': 'AssistantAgent', 'description': None, 'model_client': {'version': '1.0.0', 'component_type': 'model', 'model': 'gpt-4o-2024-08-06', 'model_type': 'OpenAIChatCompletionClient', 'api_key': None, 'base_url': None}, 'tools': [{'version': '1.0.0', 'component_type': 'tool', 'name': 'get_weather', 'description': 'Get the weather for a city', 'content': 'async def get_weather(city: str) -> str:\\n return f\"The weather in {city} is 73 degrees and Sunny.\"', 'tool_type': 'PythonFunction'}], 'system_message': None}], 'team_type': 'RoundRobinGroupChat', 'termination_condition': {'version': '1.0.0', 'component_type': 'termination', 'termination_type': 'MaxMessageTermination', 'max_messages': 5}, 'max_turns': None}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = await config_manager.import_component(\"team.json\", user_id=\"user_id\", check_exists=True)\n",
|
||||
"print(result)"
|
||||
|
@ -176,9 +206,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"message='Directory import complete' status=True data=[{'component': 'team', 'status': True, 'message': 'Team Created Successfully', 'id': 5}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"result = await config_manager.import_directory(\".\", user_id=\"user_id\", check_exists=False)\n",
|
||||
"print(result)"
|
||||
|
@ -186,9 +224,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"5 teams in database\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"all_teams = dbmanager.get(Team)\n",
|
||||
"print(len(all_teams.data), \"teams in database\")"
|
||||
|
@ -203,7 +249,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -248,52 +294,28 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"source='user' models_usage=None content='Plan a 3 day trip to Nepal.' type='TextMessage'\n",
|
||||
"source='planner_agent' models_usage=RequestUsage(prompt_tokens=45, completion_tokens=53) content='I recommend starting your trip in Kathmandu, where you can explore the historic Durbar Square and Pashupatinath Temple, then take a scenic flight over the Everest range, and finish your journey with a stunning hike in the Annapurna region.' type='TextMessage'\n",
|
||||
"source='local_agent' models_usage=RequestUsage(prompt_tokens=115, completion_tokens=53) content='I recommend starting your trip in Kathmandu, where you can explore the historic Durbar Square and Pashupatinath Temple, then take a scenic flight over the Everest range, and finish your journey with a stunning hike in the Annapurna region.' type='TextMessage'\n",
|
||||
"source='language_agent' models_usage=RequestUsage(prompt_tokens=199, completion_tokens=42) content=\"For your trip to Nepal, it's crucial to learn some phrases in Nepali since English is not widely spoken outside of major cities and tourist areas; even a simple phrasebook or translation app would be beneficial.\" type='TextMessage'\n",
|
||||
"source='travel_summary_agent' models_usage=RequestUsage(prompt_tokens=265, completion_tokens=298) content=\"Day 1: Begin your journey in Kathmandu, where you can visit the historic Durbar Square, a UNESCO World Heritage site that showcases intricate woodcarving and houses the iconic Kasthamandap Temple. From there, proceed to the sacred Pashupatinath Temple, a significant Hindu pilgrimage site on the banks of the holy Bagmati River.\\n\\nDay 2: Embark on an early morning scenic flight over the Everest range. This one-hour flight provides a breathtaking view of the world's highest peak along with other neighboring peaks. Standard flights depart from Tribhuvan International Airport between 6:30 AM to 7:30 AM depending on the weather. Spend the remainder of the day exploring the local markets in Kathmandu, sampling a variety of Nepalese cuisines and shopping for unique souvenirs.\\n\\nDay 3: Finally, take a short flight or drive to Pokhara, the gateway to the Annapurna region. Embark on a guided hike enjoying the stunning backdrop of the Annapurna ranges and the serene Phewa lake.\\n\\nRemember to bring along a phrasebook or translation app, as English is not widely spoken in Nepal, particularly outside of major cities and tourist hotspots. \\n\\nPack comfortable trekking gear, adequate water, medical and emergency supplies. It's also advisable to check on the weather updates, as conditions can change rapidly, particularly in mountainous areas. Enjoy your Nepal expedition!TERMINATE\" type='TextMessage'\n",
|
||||
"TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Plan a 3 day trip to Nepal.', type='TextMessage'), TextMessage(source='planner_agent', models_usage=RequestUsage(prompt_tokens=45, completion_tokens=53), content='I recommend starting your trip in Kathmandu, where you can explore the historic Durbar Square and Pashupatinath Temple, then take a scenic flight over the Everest range, and finish your journey with a stunning hike in the Annapurna region.', type='TextMessage'), TextMessage(source='local_agent', models_usage=RequestUsage(prompt_tokens=115, completion_tokens=53), content='I recommend starting your trip in Kathmandu, where you can explore the historic Durbar Square and Pashupatinath Temple, then take a scenic flight over the Everest range, and finish your journey with a stunning hike in the Annapurna region.', type='TextMessage'), TextMessage(source='language_agent', models_usage=RequestUsage(prompt_tokens=199, completion_tokens=42), content=\"For your trip to Nepal, it's crucial to learn some phrases in Nepali since English is not widely spoken outside of major cities and tourist areas; even a simple phrasebook or translation app would be beneficial.\", type='TextMessage'), TextMessage(source='travel_summary_agent', models_usage=RequestUsage(prompt_tokens=265, completion_tokens=298), content=\"Day 1: Begin your journey in Kathmandu, where you can visit the historic Durbar Square, a UNESCO World Heritage site that showcases intricate woodcarving and houses the iconic Kasthamandap Temple. From there, proceed to the sacred Pashupatinath Temple, a significant Hindu pilgrimage site on the banks of the holy Bagmati River.\\n\\nDay 2: Embark on an early morning scenic flight over the Everest range. This one-hour flight provides a breathtaking view of the world's highest peak along with other neighboring peaks. Standard flights depart from Tribhuvan International Airport between 6:30 AM to 7:30 AM depending on the weather. Spend the remainder of the day exploring the local markets in Kathmandu, sampling a variety of Nepalese cuisines and shopping for unique souvenirs.\\n\\nDay 3: Finally, take a short flight or drive to Pokhara, the gateway to the Annapurna region. Embark on a guided hike enjoying the stunning backdrop of the Annapurna ranges and the serene Phewa lake.\\n\\nRemember to bring along a phrasebook or translation app, as English is not widely spoken in Nepal, particularly outside of major cities and tourist hotspots. \\n\\nPack comfortable trekking gear, adequate water, medical and emergency supplies. It's also advisable to check on the weather updates, as conditions can change rapidly, particularly in mountainous areas. Enjoy your Nepal expedition!TERMINATE\", type='TextMessage')], stop_reason=\"Text 'TERMINATE' mentioned\")\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"result = group_chat.run_stream(task=\"Plan a 3 day trip to Nepal.\")\n",
|
||||
"async for response in result:\n",
|
||||
" print(response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Human in the Loop with a UserProxy Agent\n",
|
||||
"\n",
|
||||
"AutoGen studio provides a custom agent allows a human interact as part of the agent team.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogenstudio.components import UserProxyAgent\n",
|
||||
"\n",
|
||||
"def input_func(prompt: str) -> str:\n",
|
||||
" return \"Hello World there\" + str(prompt)\n",
|
||||
"user_agent = UserProxyAgent(name=\"user_agent\", description=\"a human user\", input_func=input_func)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from autogen_core import CancellationToken\n",
|
||||
"cancellation_token = CancellationToken()\n",
|
||||
"stream = user_agent.run_stream(task=\"hello there\", cancellation_token=cancellation_token)\n",
|
||||
"\n",
|
||||
"async for response in stream:\n",
|
||||
" print(response)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
@ -7,11 +7,16 @@ from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermi
|
|||
from autogen_core.tools import FunctionTool
|
||||
|
||||
from autogenstudio.datamodel.types import (
|
||||
AgentConfig,
|
||||
ModelConfig,
|
||||
TeamConfig,
|
||||
AssistantAgentConfig,
|
||||
OpenAIModelConfig,
|
||||
RoundRobinTeamConfig,
|
||||
SelectorTeamConfig,
|
||||
MagenticOneTeamConfig,
|
||||
ToolConfig,
|
||||
TerminationConfig,
|
||||
MaxMessageTerminationConfig,
|
||||
StopMessageTerminationConfig,
|
||||
TextMentionTerminationConfig,
|
||||
CombinationTerminationConfig,
|
||||
ModelTypes,
|
||||
AgentTypes,
|
||||
TeamTypes,
|
||||
|
@ -56,7 +61,7 @@ def calculator(a: int, b: int, operation: str = '+') -> int:
|
|||
|
||||
@pytest.fixture
|
||||
def sample_model_config():
|
||||
return ModelConfig(
|
||||
return OpenAIModelConfig(
|
||||
model_type=ModelTypes.OPENAI,
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
|
@ -66,8 +71,8 @@ def sample_model_config():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: ToolConfig):
|
||||
return AgentConfig(
|
||||
def sample_agent_config(sample_model_config: OpenAIModelConfig, sample_tool_config: ToolConfig):
|
||||
return AssistantAgentConfig(
|
||||
name="test_agent",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are a helpful assistant",
|
||||
|
@ -80,7 +85,7 @@ def sample_agent_config(sample_model_config: ModelConfig, sample_tool_config: To
|
|||
|
||||
@pytest.fixture
|
||||
def sample_termination_config():
|
||||
return TerminationConfig(
|
||||
return MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=10,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -90,9 +95,9 @@ def sample_termination_config():
|
|||
|
||||
@pytest.fixture
|
||||
def sample_team_config(
|
||||
sample_agent_config: AgentConfig, sample_termination_config: TerminationConfig, sample_model_config: ModelConfig
|
||||
sample_agent_config: AssistantAgentConfig, sample_termination_config: MaxMessageTerminationConfig, sample_model_config: OpenAIModelConfig
|
||||
):
|
||||
return TeamConfig(
|
||||
return RoundRobinTeamConfig(
|
||||
name="test_team",
|
||||
team_type=TeamTypes.ROUND_ROBIN,
|
||||
participants=[sample_agent_config],
|
||||
|
@ -146,14 +151,14 @@ async def test_load_tool_invalid_config(component_factory: ComponentFactory):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_model(component_factory: ComponentFactory, sample_model_config: ModelConfig):
|
||||
async def test_load_model(component_factory: ComponentFactory, sample_model_config: OpenAIModelConfig):
|
||||
# Test loading model from ModelConfig
|
||||
model = await component_factory.load_model(sample_model_config)
|
||||
assert model is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AgentConfig):
|
||||
async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AssistantAgentConfig):
|
||||
# Test loading agent from AgentConfig
|
||||
agent = await component_factory.load_agent(sample_agent_config)
|
||||
assert isinstance(agent, AssistantAgent)
|
||||
|
@ -163,8 +168,8 @@ async def test_load_agent(component_factory: ComponentFactory, sample_agent_conf
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_termination(component_factory: ComponentFactory):
|
||||
# Test MaxMessageTermination
|
||||
max_msg_config = TerminationConfig(
|
||||
|
||||
max_msg_config = MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -175,14 +180,14 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
assert termination._max_messages == 5
|
||||
|
||||
# Test StopMessageTermination
|
||||
stop_msg_config = TerminationConfig(
|
||||
stop_msg_config = StopMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE, component_type=ComponentTypes.TERMINATION, version="1.0.0"
|
||||
)
|
||||
termination = await component_factory.load_termination(stop_msg_config)
|
||||
assert isinstance(termination, StopMessageTermination)
|
||||
|
||||
# Test TextMentionTermination
|
||||
text_mention_config = TerminationConfig(
|
||||
text_mention_config = TextMentionTerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -193,17 +198,17 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
assert termination._text == "DONE"
|
||||
|
||||
# Test AND combination
|
||||
and_combo_config = TerminationConfig(
|
||||
and_combo_config = CombinationTerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="and",
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
TextMentionTerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -217,17 +222,17 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
assert termination is not None
|
||||
|
||||
# Test OR combination
|
||||
or_combo_config = TerminationConfig(
|
||||
or_combo_config = CombinationTerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="or",
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
TextMentionTerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -243,7 +248,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
# Test invalid combinations
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
CombinationTerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[], # Empty conditions
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -253,11 +258,11 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
CombinationTerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
operator="invalid", # type: ignore
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -272,16 +277,16 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
# Test missing operator
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
CombinationTerminationConfig(
|
||||
termination_type=TerminationTypes.COMBINATION,
|
||||
conditions=[
|
||||
TerminationConfig(
|
||||
MaxMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.MAX_MESSAGES,
|
||||
max_messages=5,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
),
|
||||
TerminationConfig(
|
||||
TextMentionTerminationConfig(
|
||||
termination_type=TerminationTypes.TEXT_MENTION,
|
||||
text="DONE",
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
|
@ -296,7 +301,7 @@ async def test_load_termination(component_factory: ComponentFactory):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_team(
|
||||
component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig
|
||||
component_factory: ComponentFactory, sample_team_config: RoundRobinTeamConfig, sample_model_config: OpenAIModelConfig
|
||||
):
|
||||
# Test loading RoundRobinGroupChat team
|
||||
team = await component_factory.load_team(sample_team_config)
|
||||
|
@ -304,12 +309,12 @@ async def test_load_team(
|
|||
assert len(team._participants) == 1
|
||||
|
||||
# Test loading SelectorGroupChat team with multiple participants
|
||||
selector_team_config = TeamConfig(
|
||||
selector_team_config = SelectorTeamConfig(
|
||||
name="selector_team",
|
||||
team_type=TeamTypes.SELECTOR,
|
||||
participants=[ # Add two participants
|
||||
sample_team_config.participants[0], # First agent
|
||||
AgentConfig( # Second agent
|
||||
AssistantAgentConfig( # Second agent
|
||||
name="test_agent_2",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are another helpful assistant",
|
||||
|
@ -329,12 +334,12 @@ async def test_load_team(
|
|||
assert len(team._participants) == 2
|
||||
|
||||
# Test loading MagenticOneGroupChat team
|
||||
magentic_one_config = TeamConfig(
|
||||
magentic_one_config = MagenticOneTeamConfig(
|
||||
name="magentic_one_team",
|
||||
team_type=TeamTypes.MAGENTIC_ONE,
|
||||
participants=[ # Add two participants
|
||||
sample_team_config.participants[0], # First agent
|
||||
AgentConfig( # Second agent
|
||||
AssistantAgentConfig( # Second agent
|
||||
name="test_agent_2",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
system_message="You are another helpful assistant",
|
||||
|
@ -360,7 +365,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
# Test invalid agent type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_agent(
|
||||
AgentConfig(
|
||||
AssistantAgentConfig(
|
||||
name="test",
|
||||
agent_type="InvalidAgent", # type: ignore
|
||||
system_message="test",
|
||||
|
@ -372,7 +377,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
# Test invalid team type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_team(
|
||||
TeamConfig(
|
||||
RoundRobinTeamConfig(
|
||||
name="test",
|
||||
team_type="InvalidTeam", # type: ignore
|
||||
participants=[],
|
||||
|
@ -384,7 +389,7 @@ async def test_invalid_configs(component_factory: ComponentFactory):
|
|||
# Test invalid termination type
|
||||
with pytest.raises(ValueError):
|
||||
await component_factory.load_termination(
|
||||
TerminationConfig(
|
||||
MaxMessageTerminationConfig(
|
||||
termination_type="InvalidTermination", # type: ignore
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0",
|
||||
|
|
|
@ -6,9 +6,13 @@ from typing import Generator
|
|||
|
||||
from autogenstudio.database import DatabaseManager
|
||||
from autogenstudio.datamodel.types import (
|
||||
ModelConfig, AgentConfig, ToolConfig,
|
||||
TeamConfig, ModelTypes, AgentTypes, TeamTypes, ComponentTypes,
|
||||
TerminationConfig, TerminationTypes, ToolTypes
|
||||
ToolConfig,
|
||||
OpenAIModelConfig,
|
||||
RoundRobinTeamConfig,
|
||||
StopMessageTerminationConfig,
|
||||
AssistantAgentConfig,
|
||||
ModelTypes, AgentTypes, TeamTypes, ComponentTypes,
|
||||
TerminationTypes, ToolTypes
|
||||
)
|
||||
from autogenstudio.datamodel.db import Model, Tool, Agent, Team, LinkTypes
|
||||
|
||||
|
@ -42,7 +46,7 @@ def sample_model(test_user: str) -> Model:
|
|||
"""Create a sample model with proper config"""
|
||||
return Model(
|
||||
user_id=test_user,
|
||||
config=ModelConfig(
|
||||
config=OpenAIModelConfig(
|
||||
model="gpt-4",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
|
@ -72,10 +76,10 @@ def sample_agent(test_user: str, sample_model: Model, sample_tool: Tool) -> Agen
|
|||
"""Create a sample agent with proper config and relationships"""
|
||||
return Agent(
|
||||
user_id=test_user,
|
||||
config=AgentConfig(
|
||||
config=AssistantAgentConfig(
|
||||
name="test_agent",
|
||||
agent_type=AgentTypes.ASSISTANT,
|
||||
model_client=ModelConfig.model_validate(sample_model.config),
|
||||
model_client=OpenAIModelConfig.model_validate(sample_model.config),
|
||||
tools=[ToolConfig.model_validate(sample_tool.config)],
|
||||
component_type=ComponentTypes.AGENT,
|
||||
version="1.0.0"
|
||||
|
@ -88,10 +92,11 @@ def sample_team(test_user: str, sample_agent: Agent) -> Team:
|
|||
"""Create a sample team with proper config"""
|
||||
return Team(
|
||||
user_id=test_user,
|
||||
config=TeamConfig(
|
||||
config=RoundRobinTeamConfig(
|
||||
name="test_team",
|
||||
participants=[AgentConfig.model_validate(sample_agent.config)],
|
||||
termination_condition=TerminationConfig(
|
||||
participants=[AssistantAgentConfig.model_validate(
|
||||
sample_agent.config)],
|
||||
termination_condition=StopMessageTerminationConfig(
|
||||
termination_type=TerminationTypes.STOP_MESSAGE,
|
||||
component_type=ComponentTypes.TERMINATION,
|
||||
version="1.0.0"
|
||||
|
@ -142,7 +147,7 @@ class TestDatabaseOperations:
|
|||
# Create two models with updated configs
|
||||
model1 = Model(
|
||||
user_id="test_user",
|
||||
config=ModelConfig(
|
||||
config=OpenAIModelConfig(
|
||||
model="gpt-4",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
|
@ -151,7 +156,7 @@ class TestDatabaseOperations:
|
|||
)
|
||||
model2 = Model(
|
||||
user_id="test_user",
|
||||
config=ModelConfig(
|
||||
config=OpenAIModelConfig(
|
||||
model="gpt-3.5",
|
||||
model_type=ModelTypes.OPENAI,
|
||||
component_type=ComponentTypes.MODEL,
|
||||
|
|
Loading…
Reference in New Issue