[Draft] Enable File Upload/Paste as Task in AGS (#6091)

<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?



https://github.com/user-attachments/assets/e160f16d-f42d-49e2-a6c6-687e4e6786f4



Enable file upload/paste as a task in AGS. Enables tasks like

- Can you research and fact check the ideas in this screenshot?
- Summarize this file

Only text and images are supported for now.
Under the hood, the task is constructed as a TextMessage or a multimodal
message.

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

<!-- For example: "Closes #1234" -->

Closes #5773 

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
Victor Dibia 2025-04-08 19:44:45 -07:00 committed by GitHub
parent cc806a57ef
commit 32d2a18bf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 2643 additions and 550 deletions

View File

@ -8,7 +8,7 @@ from loguru import logger
from sqlalchemy import exc, inspect, text
from sqlmodel import Session, SQLModel, and_, create_engine, select
from ..datamodel import Response, Team
from ..datamodel import BaseDBModel, Response, Team
from ..teammanager import TeamManager
from .schema_manager import SchemaManager
@ -94,7 +94,7 @@ class DatabaseManager:
finally:
self._init_lock.release()
def reset_db(self, recreate_tables: bool = True):
def reset_db(self, recreate_tables: bool = True) -> Response:
"""
Reset the database by dropping all tables and optionally recreating them.
@ -151,7 +151,7 @@ class DatabaseManager:
self._init_lock.release()
logger.info("Database reset lock released")
def upsert(self, model: SQLModel, return_json: bool = True) -> Response:
def upsert(self, model: BaseDBModel, return_json: bool = True) -> Response:
"""Create or update an entity
Args:
@ -199,7 +199,7 @@ class DatabaseManager:
def get(
self,
model_class: SQLModel,
model_class: type[BaseDBModel],
filters: dict | None = None,
return_json: bool = False,
order: str = "desc",
@ -211,7 +211,7 @@ class DatabaseManager:
status_message = ""
try:
statement = select(model_class)
statement = select(model_class) # type: ignore
if filters:
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
@ -231,7 +231,7 @@ class DatabaseManager:
return Response(message=status_message, status=status, data=result)
def delete(self, model_class: SQLModel, filters: dict = None) -> Response:
def delete(self, model_class: type[BaseDBModel], filters: dict | None = None) -> Response:
"""Delete an entity"""
status_message = ""
status = True
@ -239,8 +239,8 @@ class DatabaseManager:
with Session(self.engine) as session:
try:
if "sqlite" in str(self.engine.url):
session.exec(text("PRAGMA foreign_keys=ON"))
statement = select(model_class)
session.exec(text("PRAGMA foreign_keys=ON")) # type: ignore
statement = select(model_class) # type: ignore
if filters:
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
statement = statement.where(and_(*conditions))
@ -326,7 +326,7 @@ class DatabaseManager:
{
"status": result.status,
"message": result.message,
"id": result.data.get("id") if result.data else None,
"id": result.data.get("id") if result.data and result.data is not None else None,
}
)
@ -342,7 +342,8 @@ class DatabaseManager:
async def _check_team_exists(self, config: dict, user_id: str) -> Optional[Team]:
"""Check if identical team config already exists"""
teams = self.get(Team, {"user_id": user_id}).data
response = self.get(Team, {"user_id": user_id})
teams = response.data if response.status and response.data is not None else []
for team in teams:
if team.component == config:

View File

@ -1,4 +1,4 @@
from .db import Gallery, Message, Run, RunStatus, Session, Settings, Team
from .db import BaseDBModel, Gallery, Message, Run, RunStatus, Session, Settings, Team
from .types import (
EnvironmentVariable,
GalleryComponents,

View File

@ -2,13 +2,14 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from autogen_core import ComponentModel
from pydantic import ConfigDict, SecretStr
from sqlalchemy import ForeignKey, Integer, String
from pydantic import ConfigDict, SecretStr, field_validator
from sqlalchemy import ForeignKey, Integer
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func
from .eval import EvalJudgeCriteria, EvalRunResult, EvalRunStatus, EvalScore, EvalTask
from .types import (
GalleryComponents,
GalleryConfig,
@ -20,35 +21,41 @@ from .types import (
)
class Team(SQLModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
class BaseDBModel(SQLModel, table=False):
"""
Base model with common fields for all database tables.
Not a table itself - meant to be inherited by concrete model classes.
"""
__abstract__ = True
# Common fields present in all database tables
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
sa_type=DateTime(timezone=True), # type: ignore[assignment]
sa_column_kwargs={"server_default": func.now(), "nullable": True},
)
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
sa_type=DateTime(timezone=True), # type: ignore[assignment]
sa_column_kwargs={"onupdate": func.now(), "nullable": True},
)
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
class Team(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
class Message(SQLModel, table=True):
class Message(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[MessageConfig, dict] = Field(
default_factory=lambda: MessageConfig(source="", content=""), sa_column=Column(JSON)
)
@ -60,22 +67,18 @@ class Message(SQLModel, table=True):
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
class Session(SQLModel, table=True):
class Session(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
name: Optional[str] = None
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def parse_datetime(cls, value: Union[str, datetime]) -> datetime:
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
class RunStatus(str, Enum):
CREATED = "created"
@ -85,18 +88,11 @@ class RunStatus(str, Enum):
STOPPED = "stopped"
class Run(SQLModel, table=True):
class Run(BaseDBModel, table=True):
"""Represents a single execution run within a session"""
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
)
updated_at: datetime = Field(
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
)
session_id: Optional[int] = Field(
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
)
@ -118,19 +114,9 @@ class Run(SQLModel, table=True):
user_id: Optional[str] = None
class Gallery(SQLModel, table=True):
class Gallery(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[GalleryConfig, dict] = Field(
default_factory=lambda: GalleryConfig(
id="",
@ -149,17 +135,64 @@ class Gallery(SQLModel, table=True):
) # type: ignore[call-arg]
class Settings(SQLModel, table=True):
class Settings(BaseDBModel, table=True):
__table_args__ = {"sqlite_autoincrement": True}
id: Optional[int] = Field(default=None, primary_key=True)
created_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
) # pylint: disable=not-callable
updated_at: datetime = Field(
default_factory=datetime.now,
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
) # pylint: disable=not-callable
user_id: Optional[str] = None
version: Optional[str] = "0.0.1"
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))
# --- Evaluation system database models ---
class EvalTaskDB(BaseDBModel, table=True):
    """Database model for storing evaluation tasks.

    The task definition itself is kept in the JSON ``config`` column; ``name``
    and ``description`` are stored as separate columns next to it.
    """

    __table_args__ = {"sqlite_autoincrement": True}
    name: str = "Unnamed Task"
    description: str = ""
    # Task definition, either as an EvalTask object or its dict form.
    config: Union[EvalTask, dict] = Field(sa_column=Column(JSON))
class EvalCriteriaDB(BaseDBModel, table=True):
    """Database model for storing evaluation criteria.

    The criteria definition is kept in the JSON ``config`` column; ``name`` and
    ``description`` are stored as separate columns next to it.
    """

    __table_args__ = {"sqlite_autoincrement": True}
    name: str = "Unnamed Criteria"
    description: str = ""
    # Criteria definition, either as an EvalJudgeCriteria object or its dict form.
    config: Union[EvalJudgeCriteria, dict] = Field(sa_column=Column(JSON))
class EvalRunDB(BaseDBModel, table=True):
    """Database model for tracking evaluation runs.

    A row holds the serialized runner/judge configuration, the criteria used,
    the run status and timing, and the results once they become available.
    """

    __table_args__ = {"sqlite_autoincrement": True}

    name: str = "Unnamed Evaluation Run"
    description: str = ""

    # Reference to the evaluated task; set to NULL when that task row is deleted.
    task_id: Optional[int] = Field(
        default=None, sa_column=Column(Integer, ForeignKey("evaltaskdb.id", ondelete="SET NULL"))
    )

    # Serialized configurations for runner and judge
    runner_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
    judge_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))

    # Embedded criteria configs (EvalJudgeCriteria objects or their dict form)
    criteria_configs: List[Union[EvalJudgeCriteria, dict]] = Field(default_factory=list, sa_column=Column(JSON))

    # Run status and timing information
    status: EvalRunStatus = Field(default=EvalRunStatus.PENDING)
    start_time: Optional[datetime] = Field(default=None)
    end_time: Optional[datetime] = Field(default=None)

    # Results (updated as they become available). Both default to None, so the
    # annotations are Optional to match the actual default.
    run_result: Optional[Union[EvalRunResult, dict]] = Field(default=None, sa_column=Column(JSON))
    score_result: Optional[Union[EvalScore, dict]] = Field(default=None, sa_column=Column(JSON))

    # Additional metadata
    error_message: Optional[str] = None

View File

@ -0,0 +1,82 @@
# datamodel/eval.py
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence
from uuid import UUID, uuid4
from autogen_agentchat.base import TaskResult
from autogen_core import Image
from pydantic import BaseModel
from sqlmodel import Field
class EvalTask(BaseModel):
    """Definition of a task to be evaluated."""

    # Identifier; a fresh UUID4 is generated when none is supplied.
    task_id: UUID | str = Field(default_factory=uuid4)
    # Task payload: a plain string, or a sequence mixing strings and images.
    input: str | Sequence[str | Image]
    name: str = ""
    description: str = ""
    # Optional reference outputs; None when not provided.
    expected_outputs: Optional[List[Any]] = None
    # Free-form extra data attached to the task.
    metadata: Dict[str, Any] = {}
class EvalRunResult(BaseModel):
    """Result of an evaluation run."""

    # Task result produced by the run; None when nothing was produced.
    result: TaskResult | None = None
    # Success flag of the run; defaults to False.
    status: bool = False
    # A factory is required here: ``default=datetime.now()`` would be evaluated
    # once when the class is defined, giving every instance the same timestamp.
    start_time: Optional[datetime] = Field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    # Error description when the run did not succeed.
    error: Optional[str] = None
class EvalDimensionScore(BaseModel):
    """Score for a single evaluation dimension.

    All five fields are required; there are no defaults.
    """

    # Name of the dimension that was scored.
    dimension: str
    score: float
    # Explanation given for the score.
    reason: str
    # Bounds of the scale the score is expressed on.
    max_value: float
    min_value: float
class EvalScore(BaseModel):
    """Composite score from evaluation.

    Combines an optional overall score with the per-dimension scores.
    """

    # None until an overall score has been set.
    overall_score: Optional[float] = None
    dimension_scores: List[EvalDimensionScore] = []
    reason: Optional[str] = None
    # Default scale bounds: 0.0 to 10.0.
    max_value: float = 10.0
    min_value: float = 0.0
    # Free-form extra data attached to the score.
    metadata: Dict[str, Any] = {}
class EvalJudgeCriteria(BaseModel):
    """Criteria for judging evaluation results."""

    # Name of the dimension being judged.
    dimension: str
    # Instruction text describing how to judge this dimension.
    prompt: str
    # Default scale bounds: 0.0 to 10.0.
    max_value: float = 10.0
    min_value: float = 0.0
    # Free-form extra data attached to the criteria.
    metadata: Dict[str, Any] = {}
class EvalRunStatus(str, Enum):
    """Status of an evaluation run.

    Members are ``str`` subclasses, so each compares equal to its lowercase
    string value.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELED = "canceled"
class EvalResult(BaseModel):
    """Status and timing record of an evaluation run for one task."""

    # Identifier of the task this result belongs to.
    task_id: UUID | str
    # runner_id: UUID | str
    status: EvalRunStatus = EvalRunStatus.PENDING
    # A factory is required here: ``default=datetime.now()`` would be evaluated
    # once when the class is defined, giving every instance the same timestamp.
    start_time: Optional[datetime] = Field(default_factory=datetime.now)
    end_time: Optional[datetime] = None

View File

@ -1,9 +1,9 @@
# from dataclasses import Field
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Sequence
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import BaseChatMessage
from autogen_agentchat.messages import ChatMessage, TextMessage
from autogen_core import ComponentModel
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict, SecretStr
class MessageConfig(BaseModel):
source: str
content: str
content: str | ChatMessage | Sequence[ChatMessage] | None
message_type: Optional[str] = "text"
@ -22,9 +22,8 @@ class TeamResult(BaseModel):
duration: float
class LLMCallEventMessage(BaseChatMessage):
class LLMCallEventMessage(TextMessage):
source: str = "llm_call_event"
content: str
def to_text(self) -> str:
return self.content

View File

@ -0,0 +1,267 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from autogen_core import CancellationToken, Component, ComponentBase
from autogen_core.models import ChatCompletionClient, UserMessage
from loguru import logger
from pydantic import BaseModel
from typing_extensions import Self
from ..datamodel.eval import EvalDimensionScore, EvalJudgeCriteria, EvalRunResult, EvalScore, EvalTask
class BaseEvalJudgeConfig(BaseModel):
    """Base configuration for evaluation judges.

    Holds the fields every judge configuration carries: a display name, a
    description, and free-form metadata.
    """

    name: str = "Base Judge"
    description: str = ""
    # Free-form extra data attached to the judge.
    metadata: Dict[str, Any] = {}
class BaseEvalJudge(ABC, ComponentBase[BaseEvalJudgeConfig]):
    """Interface that every evaluation judge implements.

    Subclasses provide :meth:`judge`; this base class stores the descriptive
    attributes and knows how to turn them into a ``BaseEvalJudgeConfig``.
    """

    component_type = "eval_judge"

    def __init__(self, name: str = "Base Judge", description: str = "", metadata: Optional[Dict[str, Any]] = None):
        """Store the judge's name, description, and metadata.

        Args:
            name: Display name of the judge.
            description: Human-readable description.
            metadata: Optional extra data; a falsy value becomes an empty dict.
        """
        self.name = name
        self.description = description
        self.metadata = metadata if metadata else {}

    @abstractmethod
    async def judge(
        self,
        task: EvalTask,
        result: EvalRunResult,
        criteria: List[EvalJudgeCriteria],
        cancellation_token: Optional[CancellationToken] = None,
    ) -> EvalScore:
        """Score ``result`` for ``task`` against ``criteria`` and return the score."""
        ...

    def _to_config(self) -> BaseEvalJudgeConfig:
        """Build the serializable configuration object for this judge."""
        config_fields = {
            "name": self.name,
            "description": self.description,
            "metadata": self.metadata,
        }
        return BaseEvalJudgeConfig(**config_fields)
class LLMEvalJudgeConfig(BaseEvalJudgeConfig):
    """Configuration for LLMEvalJudge.

    Extends the base judge configuration with the model client to use.
    """

    # Typed as Any; the inline note indicates a ComponentModel is expected.
    model_client: Any  # ComponentModel
class LLMEvalJudge(BaseEvalJudge, Component[LLMEvalJudgeConfig]):
    """Judge that asks a chat-completion model to score a run result.

    Each criterion is scored by its own model call. The calls run concurrently
    and the overall score is the mean of the per-dimension scores.
    """

    component_config_schema = LLMEvalJudgeConfig
    component_type = "eval_judge"
    component_provider_override = "autogenstudio.eval.judges.LLMEvalJudge"

    def __init__(
        self,
        model_client: ChatCompletionClient,
        name: str = "LLM Judge",
        description: str = "Evaluates results using an LLM",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Create a judge backed by ``model_client``.

        Args:
            model_client: Chat-completion client used for every judgment call.
            name: Display name of the judge.
            description: Human-readable description.
            metadata: Optional extra data attached to the judge.
        """
        super().__init__(name, description, metadata)
        self.model_client = model_client

    async def judge(
        self,
        task: EvalTask,
        result: EvalRunResult,
        criteria: List[EvalJudgeCriteria],
        cancellation_token: Optional[CancellationToken] = None,
    ) -> EvalScore:
        """Judge the result using an LLM.

        Args:
            task: The task that was run.
            result: The run result to evaluate.
            criteria: Criteria to score; one model call is made per criterion.
            cancellation_token: Optional token forwarded to the model client.

        Returns:
            An ``EvalScore`` holding every dimension score. ``overall_score``
            is their average, or stays ``None`` when there are no scores.
        """
        score = EvalScore(max_value=10.0)

        # Judge each dimension in parallel
        dimension_scores = await asyncio.gather(
            *(self._judge_dimension(task, result, criterion, cancellation_token) for criterion in criteria)
        )
        score.dimension_scores = dimension_scores

        # Calculate overall score (average of dimension scores)
        valid_scores = [ds.score for ds in dimension_scores if ds.score is not None]
        if valid_scores:
            score.overall_score = sum(valid_scores) / len(valid_scores)
        return score

    async def _judge_dimension(
        self,
        task: EvalTask,
        result: EvalRunResult,
        criterion: EvalJudgeCriteria,
        cancellation_token: Optional[CancellationToken] = None,
    ) -> EvalDimensionScore:
        """Judge a single dimension with one model call.

        Returns the model's ``EvalDimensionScore`` when its reply validates.
        Otherwise returns a fallback score of 0.0 with the reason
        "Failed to parse response".
        """
        # Format task and result for the LLM
        task_description = self._format_task(task)
        result_description = result.model_dump()

        # The scale bounds come from the criterion throughout, and the JSON
        # example lists every field EvalDimensionScore requires.
        prompt = f"""
        You are evaluating the quality of a system response to a task.
        Task: {task_description}
        Response: {result_description}
        Evaluation criteria: {criterion.dimension}
        {criterion.prompt}
        Score the response on a scale from {criterion.min_value} to {criterion.max_value}.
        First, provide a detailed explanation of your evaluation.
        Then, give your final score as a single number between {criterion.min_value} and {criterion.max_value}.
        Format your answer as JSON matching the EvalDimensionScore class:
        {{
            "dimension": "{criterion.dimension}",
            "reason": "<explanation>",
            "score": <score>,
            "max_value": {criterion.max_value},
            "min_value": {criterion.min_value}
        }}
        Please ensure the score is a number between {criterion.min_value} and {criterion.max_value}.
        If you cannot evaluate the response, please return a score of null.
        If the response is not relevant, please return a score of {criterion.min_value}.
        If the response is perfect, please return a score of {criterion.max_value}.
        """

        # Get judgment from LLM
        model_input = [UserMessage(content=prompt, source="user")]
        model_result = await self.model_client.create(
            messages=model_input,
            cancellation_token=cancellation_token,
            json_output=EvalDimensionScore,
        )

        # Extract content from the response
        raw_response = model_result.content if isinstance(model_result.content, str) else str(model_result.content)

        try:
            # A null score does not validate against ``score: float`` and so
            # ends up in the fallback below.
            return EvalDimensionScore.model_validate_json(raw_response)
        except Exception as e:
            # Pass values as format arguments: loguru runs str.format on the
            # message whenever arguments are given, so braces inside the error
            # text or the raw content must not be part of the template.
            logger.warning("Failed to parse LLM response: {}; raw content: {}", e, model_result.content)
            return EvalDimensionScore(
                dimension=criterion.dimension,
                reason="Failed to parse response",
                score=0.0,
                max_value=criterion.max_value,
                min_value=criterion.min_value,
            )

    def _format_task(self, task: EvalTask) -> str:
        """Format the task for the LLM.

        Joins the description (when present) and the textual input with
        newlines. Non-string items of a sequence input are left out.
        """
        task_parts = []
        if task.description:
            task_parts.append(task.description)
        if isinstance(task.input, str):
            task_parts.append(task.input)
        else:
            # ``input`` is typed as a Sequence, so handle any sequence
            # (list, tuple, ...) rather than lists only.
            task_parts.append("\n".join(x for x in task.input if isinstance(x, str)))
        return "\n".join(task_parts)

    def _parse_judgment(self, judgment_text: str, max_value: float) -> Tuple[str, Optional[float]]:
        """Parse free-text judgment into ``(explanation, score)``.

        Looks for lines starting with ``explanation:`` and ``score:``
        (case-insensitive). The score is clamped to ``[0, max_value]``; it is
        ``None`` when no parsable score line is found.
        """
        explanation = ""
        score = None
        # Simple parsing - could be improved with regex
        for line in judgment_text.split("\n"):
            lowered = line.strip().lower()
            if lowered.startswith("explanation:"):
                explanation = line.split(":", 1)[1].strip()
            elif lowered.startswith("score:"):
                try:
                    score = float(line.split(":", 1)[1].strip())
                    # Ensure score is within bounds
                    score = min(max(score, 0), max_value)
                except (ValueError, IndexError):
                    pass
        return explanation, score

    def _to_config(self) -> LLMEvalJudgeConfig:
        """Convert to configuration object including model client configuration."""
        base_config = super()._to_config()
        return LLMEvalJudgeConfig(
            name=base_config.name,
            description=base_config.description,
            metadata=base_config.metadata,
            model_client=self.model_client.dump_component(),
        )

    @classmethod
    def _from_config(cls, config: LLMEvalJudgeConfig) -> Self:
        """Create from configuration object with serialized model client."""
        model_client = ChatCompletionClient.load_component(config.model_client)
        return cls(
            model_client=model_client, name=config.name, description=config.description, metadata=config.metadata
        )
# # Usage example
# async def example_usage():
# # Create a model client
# from autogen_ext.models.openai import OpenAIChatCompletionClient
# model_client = OpenAIChatCompletionClient(
# model="gpt-4",
# api_key="your-api-key"
# )
# # Create a judge
# llm_judge = LLMEvalJudge(model_client=model_client)
# # Serialize the judge to a ComponentModel
# judge_config = llm_judge.dump_component()
# print(f"Serialized judge: {judge_config}")
# # Deserialize back to a LLMEvalJudge
# deserialized_judge = LLMEvalJudge.load_component(judge_config)
# # Create criteria for evaluation
# criteria = [
# EvalJudgeCriteria(
# dimension="relevance",
# prompt="Evaluate how relevant the response is to the query.",
# min_value=0,
# max_value=10
# ),
# EvalJudgeCriteria(
# dimension="accuracy",
# prompt="Evaluate the factual accuracy of the response.",
# min_value=0,
# max_value=10
# )
# ]
# # Create a mock task and result
# task = EvalTask(
# task_id="task-123",
# name="Sample Task",
# description="A sample task for evaluation",
# input="What is the capital of France?"
# )
# result = EvalRunResult(
# status=True,
# result={
# "messages": [{"content": "The capital of France is Paris.", "source": "model"}]
# }
# )
# # Run the evaluation
# score = await deserialized_judge.judge(task, result, criteria)
# print(f"Evaluation score: {score}")

View File

@ -0,0 +1,789 @@
import asyncio
import uuid
from datetime import datetime
from pdb import run
from typing import Any, Dict, List, Optional, TypedDict, Union
from loguru import logger
from pydantic import BaseModel
from ..database.db_manager import DatabaseManager
from ..datamodel.db import EvalCriteriaDB, EvalRunDB, EvalTaskDB
from ..datamodel.eval import EvalJudgeCriteria, EvalRunResult, EvalRunStatus, EvalScore, EvalTask
from .judges import BaseEvalJudge
from .runners import BaseEvalRunner
class DimensionScore(TypedDict):
    """Score and reason for one evaluation dimension; either may be None."""

    score: Optional[float]
    reason: Optional[str]
class RunEntry(TypedDict):
    """One run's row in a tabulated results view."""

    id: str
    name: str
    task_name: str
    runner_type: str
    overall_score: Optional[float]
    # NOTE(review): presumably one entry per dimension, in the same order as
    # TabulatedResults.dimensions - confirm against the code that builds it.
    scores: List[Optional[float]]
    reasons: Optional[List[Optional[str]]]
class TabulatedResults(TypedDict):
    """Tabular view of results: dimension names plus one entry per run."""

    dimensions: List[str]
    runs: List[RunEntry]
class EvalOrchestrator:
"""
Orchestrator for evaluation runs.
This class manages the lifecycle of evaluation tasks, criteria, and runs.
It can operate with or without a database manager for persistence.
"""
def __init__(self, db_manager: Optional[DatabaseManager] = None):
    """
    Initialize the orchestrator.

    Sets up the optional database manager, the in-memory stores for tasks,
    criteria and runs, and the registry of currently running asyncio tasks.

    Args:
        db_manager: Optional database manager for persistence.
            If None, data is stored in memory only.
    """
    self._db_manager = db_manager
    # In-memory storage (used when db_manager is None), keyed by string ID
    self._tasks: Dict[str, EvalTask] = {}
    self._criteria: Dict[str, EvalJudgeCriteria] = {}
    self._runs: Dict[str, Dict[str, Any]] = {}
    # Active runs tracking: run ID -> asyncio task executing that run
    self._active_runs: Dict[str, asyncio.Task] = {}
# ----- Task Management -----
async def create_task(self, task: EvalTask) -> str:
    """
    Register a new evaluation task.

    The task is persisted through the database manager when one is
    configured; otherwise it is kept in the in-memory store.

    Args:
        task: The evaluation task to create. A missing ``task_id`` is filled
            in with a new UUID string.

    Returns:
        The database row ID (as a string) when persisted and the response
        carries data, otherwise the task's own ID.

    Raises:
        RuntimeError: If the database manager reports a failed store.
    """
    if not task.task_id:
        task.task_id = str(uuid.uuid4())

    if not self._db_manager:
        # No persistence layer: keep the task in memory under its own ID.
        memory_id = str(task.task_id)
        self._tasks[memory_id] = task
        return memory_id

    record = EvalTaskDB(name=task.name, description=task.description, config=task)
    stored = self._db_manager.upsert(record)
    if not stored.status:
        logger.error(f"Failed to store task: {stored.message}")
        raise RuntimeError(f"Failed to store task: {stored.message}")

    if stored.data:
        return str(stored.data.get("id"))
    return str(task.task_id)
async def get_task(self, task_id: str) -> Optional[EvalTask]:
    """
    Look up an evaluation task by ID.

    Args:
        task_id: The ID of the task to retrieve. All-digit IDs are converted
            to ``int`` for the database lookup.

    Returns:
        The task if found, None otherwise.
    """
    if not self._db_manager:
        # In-memory mode
        return self._tasks.get(task_id)

    lookup_id = int(task_id) if task_id.isdigit() else task_id
    found = self._db_manager.get(EvalTaskDB, filters={"id": lookup_id})
    if not (found.status and found.data and len(found.data) > 0):
        return None

    # The stored config may already be an EvalTask or still be raw data.
    stored_config = found.data[0].get("config")
    if isinstance(stored_config, EvalTask):
        return stored_config
    return EvalTask.model_validate(stored_config)
async def list_tasks(self) -> List[EvalTask]:
    """
    Return every known evaluation task.

    Returns:
        Tasks from the database when a manager is configured (entries with
        an empty config are skipped), otherwise the in-memory tasks.
    """
    if not self._db_manager:
        return list(self._tasks.values())

    listing = self._db_manager.get(EvalTaskDB)
    if not (listing.status and listing.data):
        return []

    stored_configs = (row.get("config") for row in listing.data)
    return [
        cfg if isinstance(cfg, EvalTask) else EvalTask.model_validate(cfg)
        for cfg in stored_configs
        if cfg
    ]
# ----- Criteria Management -----
async def create_criteria(self, criteria: EvalJudgeCriteria) -> str:
    """
    Register new evaluation criteria.

    Args:
        criteria: The evaluation criteria to create.

    Returns:
        The database row ID (as a string) when persisted and the response
        carries data, otherwise a freshly generated UUID string.

    Raises:
        RuntimeError: If the database manager reports a failed store.
    """
    generated_id = str(uuid.uuid4())

    if not self._db_manager:
        # No persistence layer: keep the criteria in memory.
        self._criteria[generated_id] = criteria
        return generated_id

    record = EvalCriteriaDB(name=criteria.dimension, description=criteria.prompt, config=criteria)
    stored = self._db_manager.upsert(record)
    if not stored.status:
        logger.error(f"Failed to store criteria: {stored.message}")
        raise RuntimeError(f"Failed to store criteria: {stored.message}")

    if stored.data:
        return str(stored.data.get("id"))
    return generated_id
async def get_criteria(self, criteria_id: str) -> Optional[EvalJudgeCriteria]:
    """
    Look up evaluation criteria by ID.

    Args:
        criteria_id: The ID of the criteria to retrieve. All-digit IDs are
            converted to ``int`` for the database lookup.

    Returns:
        The criteria if found, None otherwise.
    """
    if not self._db_manager:
        # In-memory mode
        return self._criteria.get(criteria_id)

    lookup_id = int(criteria_id) if criteria_id.isdigit() else criteria_id
    found = self._db_manager.get(EvalCriteriaDB, filters={"id": lookup_id})
    if not (found.status and found.data and len(found.data) > 0):
        return None

    # The stored config may already be an EvalJudgeCriteria or still be raw data.
    stored_config = found.data[0].get("config")
    if isinstance(stored_config, EvalJudgeCriteria):
        return stored_config
    return EvalJudgeCriteria.model_validate(stored_config)
async def list_criteria(self) -> List[EvalJudgeCriteria]:
    """
    Return every known set of evaluation criteria.

    Returns:
        Criteria from the database when a manager is configured (entries
        with an empty config are skipped), otherwise the in-memory criteria.
    """
    if not self._db_manager:
        return list(self._criteria.values())

    listing = self._db_manager.get(EvalCriteriaDB)
    if not (listing.status and listing.data):
        return []

    stored_configs = (row.get("config") for row in listing.data)
    return [
        cfg if isinstance(cfg, EvalJudgeCriteria) else EvalJudgeCriteria.model_validate(cfg)
        for cfg in stored_configs
        if cfg
    ]
# ----- Run Management -----
async def create_run(
    self,
    task: Union[str, EvalTask],
    runner: BaseEvalRunner,
    judge: BaseEvalJudge,
    criteria: List[Union[str, EvalJudgeCriteria]],
    name: str = "",
    description: str = "",
) -> str:
    """
    Create a new evaluation run configuration (the run is not started).

    Args:
        task: The task to evaluate (ID or task object)
        runner: The runner to use for evaluation
        judge: The judge to use for evaluation
        criteria: List of criteria to use for evaluation (IDs or criteria objects)
        name: Name for the run; defaults to ``"Run <generated id>"``
        description: Description for the run

    Returns:
        Run ID: the database row ID (as a string) when persisted and the
        response carries data, otherwise a generated UUID string.

    Raises:
        ValueError: If a task ID or criteria ID cannot be resolved.
        RuntimeError: If the database manager reports a failed store.
    """
    # Resolve the task: IDs are looked up, objects are used as given.
    if isinstance(task, str):
        task_obj = await self.get_task(task)
        if not task_obj:
            raise ValueError(f"Task not found: {task}")
    else:
        task_obj = task

    # Resolve criteria the same way, preserving the caller's order.
    criteria_objs = []
    for entry in criteria:
        if not isinstance(entry, str):
            criteria_objs.append(entry)
            continue
        resolved = await self.get_criteria(entry)
        if not resolved:
            raise ValueError(f"Criteria not found: {entry}")
        criteria_objs.append(resolved)

    run_id = str(uuid.uuid4())
    run_name = name or f"Run {run_id}"

    # Serialize runner and judge so the run can be rebuilt later.
    runner_config = runner.dump_component() if hasattr(runner, "dump_component") else runner._to_config()
    judge_config = judge.dump_component() if hasattr(judge, "dump_component") else judge._to_config()

    if not self._db_manager:
        # In-memory mode keeps the resolved task object alongside the configs.
        self._runs[run_id] = {
            "task": task_obj,
            "runner_config": runner_config,
            "judge_config": judge_config,
            "criteria_configs": [c.model_dump() for c in criteria_objs],
            "status": EvalRunStatus.PENDING,
            "created_at": datetime.now(),
            "run_result": None,
            "score_result": None,
            "name": run_name,
            "description": description,
        }
        return run_id

    # NOTE(review): only a numeric task ID is stored on the row; when ``task``
    # is passed as an object, task_id is None - confirm how the task is
    # recovered for database-backed runs.
    run_db = EvalRunDB(
        name=run_name,
        description=description,
        task_id=int(task) if isinstance(task, str) and task.isdigit() else None,
        runner_config=runner_config.model_dump(),
        judge_config=judge_config.model_dump(),
        criteria_configs=criteria_objs,
        status=EvalRunStatus.PENDING,
    )
    stored = self._db_manager.upsert(run_db)
    if not stored.status:
        logger.error(f"Failed to store run: {stored.message}")
        raise RuntimeError(f"Failed to store run: {stored.message}")

    if stored.data:
        return str(stored.data.get("id"))
    return run_id
async def start_run(self, run_id: str) -> None:
    """
    Schedule an evaluation run for execution.

    Does nothing except log a warning when the run is already active.
    Otherwise the run is scheduled as an asyncio task, recorded as active,
    and its status is set to RUNNING.

    Args:
        run_id: The ID of the run to start
    """
    already_active = run_id in self._active_runs
    if already_active:
        logger.warning(f"Run {run_id} is already active")
        return

    # Schedule execution in the background and remember the task handle.
    self._active_runs[run_id] = asyncio.create_task(self._execute_run(run_id))

    await self._update_run_status(run_id, EvalRunStatus.RUNNING)
async def _execute_run(self, run_id: str) -> None:
    """
    Execute an evaluation run: run the task, then judge the output.

    Results and state changes are written through the ``_update_*`` helpers.
    Any ``Exception`` raised along the way is logged and stored on the run
    instead of being propagated. The run is always removed from the set of
    active runs when this coroutine finishes.

    Args:
        run_id: The ID of the run to execute
    """
    try:
        # Load what was stored for this run
        config = await self._get_run_config(run_id)
        if not config:
            raise ValueError(f"Run not found: {run_id}")

        task = config.get("task")
        if not task:
            raise ValueError(f"Task not found for run: {run_id}")

        # Rebuild the runner and the judge from their serialized configs
        runner_cfg = config.get("runner_config")
        runner = None
        if runner_cfg:
            runner = BaseEvalRunner.load_component(runner_cfg)

        judge_cfg = config.get("judge_config")
        judge = None
        if judge_cfg:
            judge = BaseEvalJudge.load_component(judge_cfg)

        if not (runner and judge):
            raise ValueError(f"Runner or judge not found for run: {run_id}")

        # Criteria may be stored either as model instances or as plain data
        criteria = [
            item if isinstance(item, EvalJudgeCriteria) else EvalJudgeCriteria.model_validate(item)
            for item in (config.get("criteria_configs") or [])
        ]

        # Runner phase
        logger.info(f"Starting runner for run {run_id}")
        started_at = datetime.now()
        run_result = await runner.run(task)
        await self._update_run_result(run_id, run_result)

        if not run_result.status:
            # Nothing to judge when the runner itself failed
            logger.error(f"Runner failed for run {run_id}: {run_result.error}")
            await self._update_run_status(run_id, EvalRunStatus.FAILED)
            return

        # Judge phase
        logger.info(f"Starting judge for run {run_id}")
        score_result = await judge.judge(task, run_result, criteria)
        await self._update_score_result(run_id, score_result)

        # Mark the run as finished
        finished_at = datetime.now()
        await self._update_run_completed(run_id, started_at, finished_at)
        logger.info(f"Run {run_id} completed successfully")
    except Exception as exc:
        logger.exception(f"Error executing run {run_id}: {str(exc)}")
        await self._update_run_error(run_id, str(exc))
    finally:
        # The run is no longer active, whatever the outcome
        self._active_runs.pop(run_id, None)
async def get_run_status(self, run_id: str) -> Optional[EvalRunStatus]:
    """
    Look up the current status of an evaluation run.

    Args:
        run_id: The ID of the run

    Returns:
        The stored status, or None when the run (or its status) is not found
    """
    config = await self._get_run_config(run_id)
    if not config:
        return None
    return config.get("status")
async def get_run_result(self, run_id: str) -> Optional[EvalRunResult]:
    """
    Fetch the runner output recorded for an evaluation run.

    Args:
        run_id: The ID of the run

    Returns:
        The run result as an EvalRunResult, or None when the run is unknown
        or has no result yet
    """
    config = await self._get_run_config(run_id)
    stored = config.get("run_result") if config else None
    if not stored:
        return None

    # Stored data may already be a model instance or still be plain data
    if isinstance(stored, EvalRunResult):
        return stored
    return EvalRunResult.model_validate(stored)
async def get_run_score(self, run_id: str) -> Optional[EvalScore]:
    """
    Fetch the judge score recorded for an evaluation run.

    Args:
        run_id: The ID of the run

    Returns:
        The score as an EvalScore, or None when the run is unknown or has
        not been scored yet
    """
    config = await self._get_run_config(run_id)
    stored = config.get("score_result") if config else None
    if not stored:
        return None

    # Stored data may already be a model instance or still be plain data
    if isinstance(stored, EvalScore):
        return stored
    return EvalScore.model_validate(stored)
async def list_runs(self) -> List[Dict[str, Any]]:
    """
    Summarize every known evaluation run.

    Returns:
        One dictionary per run with the keys ``id``, ``name``, ``status``,
        ``created_at`` and ``updated_at``
    """
    if not self._db_manager:
        # In-memory store: fall back to the creation time when a run was never updated
        return [
            {
                "id": run_id,
                "name": stored.get("name"),
                "status": stored.get("status"),
                "created_at": stored.get("created_at"),
                "updated_at": stored.get("updated_at", stored.get("created_at")),
            }
            for run_id, stored in self._runs.items()
        ]

    # Database store
    response = self._db_manager.get(EvalRunDB)
    if not (response.status and response.data):
        return []

    summary_keys = ("id", "name", "status", "created_at", "updated_at")
    return [{key: row.get(key) for key in summary_keys} for row in response.data]
async def cancel_run(self, run_id: str) -> bool:
    """
    Cancel an active evaluation run.

    Args:
        run_id: The ID of the run to cancel

    Returns:
        True if the run was cancelled, False when the run is not active or
        the cancellation could not be recorded
    """
    run_task = self._active_runs.get(run_id)
    if run_task is None:
        logger.warning(f"Run {run_id} is not active")
        return False

    try:
        run_task.cancel()
        await self._update_run_status(run_id, EvalRunStatus.CANCELED)
        # The cancelled task removes itself from _active_runs in its own
        # ``finally`` block. If that cleanup ran while we were awaiting the
        # status update, a plain ``del`` would raise KeyError and a successful
        # cancellation would be reported as a failure.
        self._active_runs.pop(run_id, None)
        return True
    except Exception as e:
        logger.error(f"Failed to cancel run {run_id}: {str(e)}")
        return False
# ----- Helper Methods -----
async def _get_run_config(self, run_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the stored configuration and state of an evaluation run.

    Args:
        run_id: The ID of the run

    Returns:
        A dictionary describing the run, or None when the run is not found
    """
    if not self._db_manager:
        # In-memory store
        return self._runs.get(run_id)

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return None
    run_data = response.data[0]

    # Resolve the task referenced by the run, if there is one
    task = None
    task_id = run_data.get("task_id")
    if task_id:
        task_response = self._db_manager.get(EvalTaskDB, filters={"id": task_id})
        if task_response.status and task_response.data and len(task_response.data) > 0:
            stored_task = task_response.data[0].get("config")
            if isinstance(stored_task, EvalTask):
                task = stored_task
            else:
                task = EvalTask.model_validate(stored_task)

    config: Dict[str, Any] = {"task": task}
    copied_keys = (
        "runner_config",
        "judge_config",
        "criteria_configs",
        "status",
        "run_result",
        "score_result",
        "name",
        "description",
        "created_at",
        "updated_at",
    )
    for key in copied_keys:
        config[key] = run_data.get(key)
    return config
async def _update_run_status(self, run_id: str, status: EvalRunStatus) -> None:
    """
    Store a new status for an evaluation run.

    Runs that cannot be found are left untouched.

    Args:
        run_id: The ID of the run
        status: The new status
    """
    if not self._db_manager:
        # In-memory store
        if run_id in self._runs:
            entry = self._runs[run_id]
            entry["status"] = status
            entry["updated_at"] = datetime.now()
        return

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return

    run_db = EvalRunDB.model_validate(response.data[0])
    run_db.status = status
    run_db.updated_at = datetime.now()
    self._db_manager.upsert(run_db)
async def _update_run_result(self, run_id: str, run_result: EvalRunResult) -> None:
    """
    Store the runner output of an evaluation run.

    Runs that cannot be found are left untouched.

    Args:
        run_id: The ID of the run
        run_result: The run result
    """
    if not self._db_manager:
        # In-memory store
        if run_id in self._runs:
            entry = self._runs[run_id]
            entry["run_result"] = run_result
            entry["updated_at"] = datetime.now()
        return

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return

    run_db = EvalRunDB.model_validate(response.data[0])
    run_db.run_result = run_result
    run_db.updated_at = datetime.now()
    self._db_manager.upsert(run_db)
async def _update_score_result(self, run_id: str, score_result: EvalScore) -> None:
    """
    Store the judge score of an evaluation run.

    Runs that cannot be found are left untouched.

    Args:
        run_id: The ID of the run
        score_result: The score result
    """
    if not self._db_manager:
        # In-memory store
        if run_id in self._runs:
            entry = self._runs[run_id]
            entry["score_result"] = score_result
            entry["updated_at"] = datetime.now()
        return

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return

    run_db = EvalRunDB.model_validate(response.data[0])
    run_db.score_result = score_result
    run_db.updated_at = datetime.now()
    self._db_manager.upsert(run_db)
async def _update_run_completed(self, run_id: str, start_time: datetime, end_time: datetime) -> None:
    """
    Mark an evaluation run as completed and record its timing.

    Runs that cannot be found are left untouched.

    Args:
        run_id: The ID of the run
        start_time: The start time
        end_time: The end time
    """
    if not self._db_manager:
        # In-memory store
        if run_id in self._runs:
            entry = self._runs[run_id]
            entry["status"] = EvalRunStatus.COMPLETED
            entry["start_time"] = start_time
            entry["end_time"] = end_time
            entry["updated_at"] = datetime.now()
        return

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return

    run_db = EvalRunDB.model_validate(response.data[0])
    run_db.status = EvalRunStatus.COMPLETED
    run_db.start_time = start_time
    run_db.end_time = end_time
    run_db.updated_at = datetime.now()
    self._db_manager.upsert(run_db)
async def _update_run_error(self, run_id: str, error_message: str) -> None:
    """
    Mark an evaluation run as failed and record why.

    The end time is set to the moment of the update. Runs that cannot be
    found are left untouched.

    Args:
        run_id: The ID of the run
        error_message: The error message
    """
    if not self._db_manager:
        # In-memory store
        if run_id in self._runs:
            entry = self._runs[run_id]
            entry["status"] = EvalRunStatus.FAILED
            entry["error_message"] = error_message
            entry["end_time"] = datetime.now()
            entry["updated_at"] = datetime.now()
        return

    # Database store: numeric IDs are looked up as integers
    run_key = int(run_id) if run_id.isdigit() else run_id
    response = self._db_manager.get(EvalRunDB, filters={"id": run_key})
    if not (response.status and response.data and len(response.data) > 0):
        return

    run_db = EvalRunDB.model_validate(response.data[0])
    run_db.status = EvalRunStatus.FAILED
    run_db.error_message = error_message
    run_db.end_time = datetime.now()
    run_db.updated_at = datetime.now()
    self._db_manager.upsert(run_db)
async def tabulate_results(self, run_ids: List[str], include_reasons: bool = False) -> TabulatedResults:
    """
    Generate a tabular representation of evaluation results across runs.

    Scores are collected per run and aligned to one shared, sorted list of
    dimensions, which makes the output easy to feed into visualizations such
    as radar charts. Runs without a configuration or without a score are
    omitted. A scored run that has no per-dimension scores still appears,
    with ``None`` in every dimension slot.

    Args:
        run_ids: List of run IDs to include in the tabulation
        include_reasons: Whether to include scoring reasons in the output

    Returns:
        A dictionary with a ``dimensions`` list and a ``runs`` list; each run
        entry carries ``scores`` (and ``reasons`` when requested) in the same
        order as ``dimensions``
    """
    result: TabulatedResults = {"dimensions": [], "runs": []}

    # Fetch every run's configuration and score concurrently
    count = len(run_ids)
    fetched = await asyncio.gather(
        *(self._get_run_config(run_id) for run_id in run_ids),
        *(self.get_run_score(run_id) for run_id in run_ids),
    )
    # Keyed by run ID, so a repeated ID yields a single entry
    run_data = dict(zip(run_ids, zip(fetched[:count], fetched[count:])))

    # Collect the union of all scored dimensions
    dimensions_set = set()
    for _, score in run_data.values():
        if score:
            # dimension_scores may be missing/None on a score; treat that as "no dimensions"
            for dim_score in score.dimension_scores or ():
                dimensions_set.add(dim_score.dimension)
    result["dimensions"] = sorted(dimensions_set)

    for run_id, (run_config, score) in run_data.items():
        if not run_config or not score:
            continue

        # Derive the runner type from the provider string of the runner config
        runner_type = "unknown"
        runner_config = run_config.get("runner_config")
        if runner_config and "provider" in runner_config:
            provider = runner_config["provider"]
            if "ModelEvalRunner" in provider:
                runner_type = "model"
            elif "TeamEvalRunner" in provider:
                runner_type = "team"

        task = run_config.get("task")
        task_name = task.name if task else "Unknown Task"

        run_entry: RunEntry = {
            "id": run_id,
            "name": run_config.get("name", f"Run {run_id}"),
            "task_name": task_name,
            "runner_type": runner_type,
            "overall_score": score.overall_score,
            "scores": [],
            "reasons": [] if include_reasons else None,
        }

        # Same guard as above: previously a None value raised TypeError here
        dim_map = {ds.dimension: ds for ds in (score.dimension_scores or ())}

        # Emit one slot per dimension so every run lines up with result["dimensions"]
        for dim in result["dimensions"]:
            dim_score = dim_map.get(dim)
            run_entry["scores"].append(dim_score.score if dim_score is not None else None)
            if include_reasons:
                run_entry["reasons"].append(dim_score.reason if dim_score is not None else None)  # type: ignore

        result["runs"].append(run_entry)

    return result

View File

@ -0,0 +1,201 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, Optional, Sequence, Type, Union
from autogen_agentchat.base import TaskResult, Team
from autogen_agentchat.messages import ChatMessage, MultiModalMessage, TextMessage
from autogen_core import CancellationToken, Component, ComponentBase, ComponentModel, Image
from autogen_core.models import ChatCompletionClient, UserMessage
from pydantic import BaseModel
from typing_extensions import Self
from ..datamodel.eval import EvalRunResult, EvalTask
class BaseEvalRunnerConfig(BaseModel):
    """Base configuration for evaluation runners."""

    # Name of the runner (required; no default).
    name: str
    # Free-form description of the runner; empty by default.
    description: str = ""
    # Arbitrary extra key/value data attached to the runner; empty by default.
    metadata: Dict[str, Any] = {}
class BaseEvalRunner(ABC, ComponentBase[BaseEvalRunnerConfig]):
    """Base class for evaluation runners that defines the interface for running evaluations.
    This class provides the core interface that all evaluation runners must implement.
    Subclasses should implement the run method to define how a specific evaluation is executed.
    """

    # NOTE(review): the class docstring above is kept verbatim on purpose; it may be
    # used as the default description when the component is dumped. Confirm before rewording.

    component_type = "eval_runner"

    def __init__(self, name: str, description: str = "", metadata: Optional[Dict[str, Any]] = None):
        self.name = name
        self.description = description
        # Every instance gets its own dict when no (or empty) metadata is supplied
        self.metadata = metadata if metadata else {}

    @abstractmethod
    async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
        """Execute the given task and report the outcome.

        Args:
            task: The task to evaluate
            cancellation_token: Optional token to cancel the evaluation

        Returns:
            EvalRunResult: The result of running the task
        """
        ...

    def _to_config(self) -> BaseEvalRunnerConfig:
        """Build the serializable configuration holding this runner's name, description and metadata."""
        return BaseEvalRunnerConfig(
            name=self.name,
            description=self.description,
            metadata=self.metadata,
        )
class ModelEvalRunnerConfig(BaseEvalRunnerConfig):
    """Configuration for ModelEvalRunner."""

    # Serialized component configuration of the model client; ModelEvalRunner
    # writes it with dump_component() and reads it back with
    # ChatCompletionClient.load_component().
    model_client: ComponentModel
class ModelEvalRunner(BaseEvalRunner, Component[ModelEvalRunnerConfig]):
    """Evaluation runner that uses a single LLM to process tasks.
    This runner sends the task directly to a model client and returns the response.
    """

    component_config_schema = ModelEvalRunnerConfig
    component_type = "eval_runner"
    component_provider_override = "autogenstudio.eval.runners.ModelEvalRunner"

    def __init__(
        self,
        model_client: ChatCompletionClient,
        name: str = "Model Runner",
        description: str = "Evaluates tasks using a single LLM",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(name, description, metadata)
        self.model_client = model_client

    async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
        """Send the task input to the model client and wrap the reply in an EvalRunResult.

        Args:
            task: The task whose ``input`` (a string or a list of content parts) is sent as one user message
            cancellation_token: Optional token forwarded to the model client

        Returns:
            EvalRunResult: ``status=True`` with the model reply as a single text message on
            success; ``status=False`` with the error text when anything raises
        """
        # Taken before the model call so start/end actually bracket the work
        start_time = datetime.now()
        try:
            model_input = []
            if isinstance(task.input, str):
                model_input.append(UserMessage(content=task.input, source="user"))
            elif isinstance(task.input, list):
                model_input.append(UserMessage(content=list(task.input), source="user"))

            model_result = await self.model_client.create(messages=model_input, cancellation_token=cancellation_token)

            # Use the reply text when the model returned text. The previous check tested the
            # result object itself for being a ``str``, which never holds, so the whole result
            # dump was always stringified. Non-text content still falls back to the dump.
            content = model_result.content
            model_response = content if isinstance(content, str) else model_result.model_dump()

            task_result = TaskResult(
                messages=[TextMessage(content=str(model_response), source="model")],
            )
            return EvalRunResult(result=task_result, status=True, start_time=start_time, end_time=datetime.now())
        except Exception as e:
            # Failures are reported through the result object rather than raised
            return EvalRunResult(status=False, error=str(e), start_time=start_time, end_time=datetime.now())

    def _to_config(self) -> ModelEvalRunnerConfig:
        """Convert to configuration object including model client configuration."""
        base_config = super()._to_config()
        return ModelEvalRunnerConfig(
            name=base_config.name,
            description=base_config.description,
            metadata=base_config.metadata,
            model_client=self.model_client.dump_component(),
        )

    @classmethod
    def _from_config(cls, config: ModelEvalRunnerConfig) -> Self:
        """Create from configuration object with serialized model client."""
        model_client = ChatCompletionClient.load_component(config.model_client)
        return cls(
            name=config.name,
            description=config.description,
            metadata=config.metadata,
            model_client=model_client,
        )
class TeamEvalRunnerConfig(BaseEvalRunnerConfig):
    """Configuration for TeamEvalRunner."""

    # Serialized component configuration of the team; TeamEvalRunner writes it
    # with dump_component() and reads it back with Team.load_component().
    team: ComponentModel
class TeamEvalRunner(BaseEvalRunner, Component[TeamEvalRunnerConfig]):
    """Evaluation runner that uses a team of agents to process tasks.
    This runner creates and runs a team based on a team configuration.
    """

    component_config_schema = TeamEvalRunnerConfig
    component_type = "eval_runner"
    component_provider_override = "autogenstudio.eval.runners.TeamEvalRunner"

    def __init__(
        self,
        team: Union[Team, ComponentModel],
        name: str = "Team Runner",
        description: str = "Evaluates tasks using a team of agents",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(name, description, metadata)
        # Accept either a live team or its serialized component configuration
        self._team = team if isinstance(team, Team) else Team.load_component(team)

    async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
        """Run the task with the team and wrap the outcome in an EvalRunResult.

        Args:
            task: The task whose ``input`` (a string, or a list of strings and images) is
                converted into chat messages for the team
            cancellation_token: Optional token forwarded to the team

        Returns:
            EvalRunResult: ``status=True`` with the team's result on success;
            ``status=False`` with the error text when anything raises
        """
        # Taken before the team runs so start/end actually bracket the work
        start_time = datetime.now()
        try:
            team_task: list[ChatMessage] = []
            if isinstance(task.input, str):
                team_task.append(TextMessage(content=task.input, source="user"))
            elif isinstance(task.input, list):
                # Each part becomes its own message; parts of any other type are ignored
                for message in task.input:
                    if isinstance(message, str):
                        team_task.append(TextMessage(content=message, source="user"))
                    elif isinstance(message, Image):
                        team_task.append(MultiModalMessage(source="user", content=[message]))

            team_result = await self._team.run(task=team_task, cancellation_token=cancellation_token)
            return EvalRunResult(result=team_result, status=True, start_time=start_time, end_time=datetime.now())
        except Exception as e:
            # Failures are reported through the result object rather than raised
            return EvalRunResult(status=False, error=str(e), start_time=start_time, end_time=datetime.now())

    def _to_config(self) -> TeamEvalRunnerConfig:
        """Convert to configuration object including team configuration."""
        base_config = super()._to_config()
        return TeamEvalRunnerConfig(
            name=base_config.name,
            description=base_config.description,
            metadata=base_config.metadata,
            team=self._team.dump_component(),
        )

    @classmethod
    def _from_config(cls, config: TeamEvalRunnerConfig) -> Self:
        """Create from configuration object with serialized team configuration."""
        return cls(
            team=Team.load_component(config.team),
            name=config.name,
            description=config.description,
            metadata=config.metadata,
        )

View File

@ -4,18 +4,19 @@ import logging
import os
import time
from pathlib import Path
from typing import AsyncGenerator, Callable, List, Optional, Union
from typing import AsyncGenerator, Callable, List, Optional, Sequence, Union
import aiofiles
import yaml
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import TaskResult, Team
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_agentchat.teams import BaseGroupChat
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, ComponentModel
from autogen_core.logging import LLMCallEvent
from ..datamodel.types import EnvironmentVariable, LLMCallEventMessage, TeamResult
from ..web.managers.run_context import RunContext
logger = logging.getLogger(__name__)
@ -35,6 +36,10 @@ class RunEventLogger(logging.Handler):
class TeamManager:
"""Manages team operations including loading configs and running teams"""
def __init__(self):
self._team: Optional[BaseGroupChat] = None
self._run_context = RunContext()
@staticmethod
async def load_from_file(path: Union[str, Path]) -> dict:
"""Load team configuration from JSON/YAML file"""
@ -87,17 +92,17 @@ class TeamManager:
for var in env_vars:
os.environ[var.name] = var.value
team: BaseGroupChat = BaseGroupChat.load_component(config)
self._team = BaseGroupChat.load_component(config)
for agent in team._participants:
for agent in self._team._participants:
if hasattr(agent, "input_func") and isinstance(agent, UserProxyAgent) and input_func:
agent.input_func = input_func
return team
return self._team
async def run_stream(
self,
task: str,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None,
team_config: Union[str, Path, dict, ComponentModel],
input_func: Optional[Callable] = None,
cancellation_token: Optional[CancellationToken] = None,
@ -142,7 +147,7 @@ class TeamManager:
async def run(
self,
task: str,
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None,
team_config: Union[str, Path, dict, ComponentModel],
input_func: Optional[Callable] = None,
cancellation_token: Optional[CancellationToken] = None,

View File

@ -1,262 +1,71 @@
import base64
import hashlib
import os
import re
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Sequence
from dotenv import load_dotenv
from autogen_agentchat.messages import ChatMessage, MultiModalMessage, TextMessage
from autogen_core import Image
from autogen_core.models import UserMessage
from loguru import logger
from ..version import APP_NAME
def sha256_hash(text: str) -> str:
def construct_task(query: str, files: list[dict] | None = None) -> Sequence[ChatMessage]:
"""
Compute the SHA-256 hash of a given text.
Construct a task from a query string and list of files.
Returns a list of ChatMessage objects suitable for processing by the agent system.
:param text: The string to hash
:return: The SHA-256 hash of the text, hex-encoded.
Args:
query: The text query from the user
files: List of file objects with properties name, content, and type
Returns:
List of BaseChatMessage objects (TextMessage, MultiModalMessage)
"""
return hashlib.sha256(text.encode()).hexdigest()
if files is None:
files = []
messages = []
def check_and_cast_datetime_fields(obj: Any) -> Any:
if hasattr(obj, "created_at") and isinstance(obj.created_at, str):
obj.created_at = str_to_datetime(obj.created_at)
# Add the user's text query as a TextMessage
if query:
messages.append(TextMessage(source="user", content=query))
if hasattr(obj, "updated_at") and isinstance(obj.updated_at, str):
obj.updated_at = str_to_datetime(obj.updated_at)
return obj
def str_to_datetime(dt_str: str) -> datetime:
    """
    Parse an ISO 8601 timestamp string into a datetime.

    A trailing ``Z`` (UTC designator) is accepted and treated as ``+00:00``.

    :param dt_str: The ISO 8601 formatted string to parse.
    :return: The parsed datetime (timezone-aware when the string carries an offset or ``Z``).
    :raises ValueError: If the string is empty or not a valid ISO 8601 timestamp.
    """
    # endswith() is safe on an empty string, where indexing [-1] raised IndexError
    if dt_str.endswith("Z"):
        # Replace 'Z' with '+00:00' for UTC timezone
        dt_str = dt_str[:-1] + "+00:00"
    return datetime.fromisoformat(dt_str)
def get_file_type(file_path: str) -> str:
    """
    Get file type determined by the file extension. If the file extension is not
    recognized, 'unknown' will be used as the file type.

    The extension is compared case-insensitively, so ``photo.PNG`` and
    ``photo.png`` are both classified as images.

    :param file_path: The path to the file to be classified.
    :return: One of 'code', 'csv', 'image', 'pdf', 'video' or 'unknown'.
    """
    # Extended list of file extensions for code and text files
    CODE_EXTENSIONS = {
        ".py",
        ".js",
        ".jsx",
        ".java",
        ".c",
        ".cpp",
        ".cs",
        ".ts",
        ".tsx",
        ".html",
        ".css",
        ".scss",
        ".less",
        ".json",
        ".xml",
        ".yaml",
        ".yml",
        ".md",
        ".rst",
        ".tex",
        ".sh",
        ".bat",
        ".ps1",
        ".php",
        ".rb",
        ".go",
        ".swift",
        ".kt",
        ".hs",
        ".scala",
        ".lua",
        ".pl",
        ".sql",
        ".config",
    }
    # Supported spreadsheet extensions
    CSV_EXTENSIONS = {".csv", ".xlsx"}
    # Supported image extensions
    IMAGE_EXTENSIONS = {
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".bmp",
        ".tiff",
        ".svg",
        ".webp",
    }
    # Supported (web) video extensions
    VIDEO_EXTENSIONS = {".mp4", ".webm", ".ogg", ".mov", ".avi", ".wmv"}
    # Supported PDF extension
    PDF_EXTENSION = ".pdf"

    # Only the last suffix counts ("archive.tar.gz" -> ".gz"); lower-cased so
    # upper- or mixed-case extensions are recognized too
    _, file_extension = os.path.splitext(file_path)
    file_extension = file_extension.lower()

    # Determine the file type based on the extension
    if file_extension in CODE_EXTENSIONS:
        return "code"
    if file_extension in CSV_EXTENSIONS:
        return "csv"
    if file_extension in IMAGE_EXTENSIONS:
        return "image"
    if file_extension == PDF_EXTENSION:
        return "pdf"
    if file_extension in VIDEO_EXTENSIONS:
        return "video"
    return "unknown"
def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir: str) -> List[Dict[str, str]]:
    """
    List the files under source_dir whose modification time lies within a timestamp range.

    Directories named "__pycache__" or "__init__.py" are not descended into, and
    files with those names or with a ".pyc" / ".cache" extension are skipped.

    :param start_timestamp: Lower bound (inclusive) of the modification time, as a float timestamp.
    :param end_timestamp: Upper bound (inclusive) of the modification time, as a float timestamp.
    :param source_dir: The directory tree to search.
    :return: A list of dictionaries of the form {path: "", name: "", extension: "", type: ""},
             sorted by extension. "path" starts at "files/user" and is empty when the
             file's path does not contain that marker.
    """
    skipped_names = {"__pycache__", "__init__.py"}
    skipped_extensions = {".pyc", ".cache"}
    user_marker = "files/user"
    found = []

    for root, dirs, files in os.walk(source_dir):
        # Prune in place so os.walk does not descend into ignored directories
        dirs[:] = [name for name in dirs if name not in skipped_names]

        for file in files:
            if file in skipped_names:
                continue
            extension = os.path.splitext(file)[1]
            if extension in skipped_extensions:
                continue

            file_path = os.path.join(root, file)
            modified_at = os.path.getmtime(file_path)
            if not (start_timestamp <= modified_at <= end_timestamp):
                continue

            # Keep only the part of the path from the first "files/user" onwards
            _, marker, tail = file_path.partition(user_marker)
            relative_path = marker + tail if marker else ""

            found.append(
                {
                    "path": relative_path,
                    "name": os.path.basename(file),
                    # Remove the dot
                    "extension": extension.lstrip("."),
                    "type": get_file_type(file_path),
                }
            )

    # Sort the modified files by extension
    return sorted(found, key=lambda entry: entry["extension"])
def get_app_root() -> str:
    """
    Get the root directory of the application.

    The ``AUTOGENSTUDIO_APPDIR`` environment variable takes precedence when it
    is set to a non-empty value. Otherwise the default ``~/.<APP_NAME>``
    directory is used and created if it does not exist yet.

    :return: The root directory of the application.
    """
    configured_root = os.environ.get("AUTOGENSTUDIO_APPDIR")
    if configured_root:
        # An explicit override must not create (or need write access to) the
        # unused default directory in the user's home
        return configured_root

    default_app_root = os.path.join(os.path.expanduser("~"), f".{APP_NAME}")
    os.makedirs(default_app_root, exist_ok=True)
    return default_app_root
def get_db_uri(app_root: str) -> str:
    """
    Resolve the database URI the application should use.

    The ``AUTOGENSTUDIO_DATABASE_URI`` environment variable wins when it is set
    to a non-empty value; otherwise a SQLite file inside app_root is used.

    :param app_root: The root directory of the application.
    :return: The database URI.
    """
    sqlite_file = os.path.join(app_root, "database.sqlite")
    fallback_uri = f"sqlite:///{sqlite_file}"
    db_uri = os.environ.get("AUTOGENSTUDIO_DATABASE_URI") or fallback_uri
    logger.info(f"Using database URI: {db_uri}")
    return db_uri
def init_app_folders(app_file_path: str) -> Dict[str, str]:
    """
    Create the folders the web server needs and load the application's .env file.

    Creates the application root, its "files/" and "files/user" directories, and
    a "ui" directory under app_file_path. A ".env" file in the application root
    is loaded when present.

    :param app_file_path: The directory under which the "ui" static folder is created.
    :return: A dictionary with the keys "files_static_root", "static_folder_root",
             "app_root" and "database_engine_uri".
    """
    app_root = get_app_root()
    if not os.path.exists(app_root):
        os.makedirs(app_root, exist_ok=True)

    # load .env file if it exists
    env_file = os.path.join(app_root, ".env")
    if os.path.exists(env_file):
        logger.info(f"Loaded environment variables from {env_file}")
        load_dotenv(env_file)

    files_static_root = os.path.join(app_root, "files/")
    static_folder_root = os.path.join(app_file_path, "ui")
    user_files_root = os.path.join(files_static_root, "user")
    for folder in (files_static_root, user_files_root, static_folder_root):
        os.makedirs(folder, exist_ok=True)

    folders = dict(
        files_static_root=files_static_root,
        static_folder_root=static_folder_root,
        app_root=app_root,
        database_engine_uri=get_db_uri(app_root=app_root),
    )
    logger.info(f"Initialized application data folder: {app_root}")
    return folders
class Version:
def __init__(self, ver_str: str):
# Process each file based on its type
for file in files:
try:
# Split into major.minor.patch
self.major, self.minor, self.patch = map(int, ver_str.split("."))
except (ValueError, AttributeError) as err:
raise ValueError(f"Invalid version format: {ver_str}. Expected: major.minor.patch") from err
if file.get("type", "").startswith("image/"):
# Handle image file using from_base64 method
# The content is already base64 encoded according to the convertFilesToBase64 function
image = Image.from_base64(file["content"])
messages.append(
MultiModalMessage(
source="user", content=[image], metadata={"filename": file.get("name", "unknown.img")}
)
)
elif file.get("type", "").startswith("text/"):
# Handle text file as TextMessage
text_content = base64.b64decode(file["content"]).decode("utf-8")
messages.append(
TextMessage(
source="user", content=text_content, metadata={"filename": file.get("name", "unknown.txt")}
)
)
else:
# Log unsupported file types but still try to process based on best guess
logger.warning(f"Potentially unsupported file type: {file.get('type')} for file {file.get('name')}")
if file.get("type", "").startswith("application/"):
# Try to treat as text if it's an application type (like JSON)
text_content = base64.b64decode(file["content"]).decode("utf-8")
messages.append(
TextMessage(
source="user",
content=text_content,
metadata={
"filename": file.get("name", "unknown.file"),
"filetype": file.get("type", "unknown"),
},
)
)
except Exception as e:
logger.error(f"Error processing file {file.get('name')}: {str(e)}")
# Continue processing other files even if one fails
def __str__(self):
return f"{self.major}.{self.minor}.{self.patch}"
def __eq__(self, other):
if isinstance(other, str):
other = Version(other)
return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)
def __gt__(self, other):
if isinstance(other, str):
other = Version(other)
return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch)
return messages

View File

@ -63,7 +63,7 @@ class ComponentTestService:
if status:
logs.append(
f"Agent responded with: {response.chat_message.content} to the question : {test_question}"
f"Agent responded with: {response.chat_message.to_text()} to the question : {test_question}"
)
else:
logs.append("Agent did not return a valid response")

View File

@ -1,3 +1,3 @@
VERSION = "0.4.1"
VERSION = "0.4.2"
__version__ = VERSION
APP_NAME = "autogenstudio"

View File

@ -109,8 +109,8 @@ async def register_auth_dependencies(app: FastAPI, auth_manager: AuthManager) ->
for route in app.routes:
# print(" *** Route: ", route.path)
if hasattr(route, "app") and isinstance(route.app, FastAPI):
route.app.state.auth_manager = auth_manager
if hasattr(route, "app") and isinstance(route.app, FastAPI): # type: ignore
route.app.state.auth_manager = auth_manager # type: ignore
# Manager initialization and cleanup

View File

@ -1 +1 @@
from .connection import WebSocketManager
# from .connection import WebSocketManager

View File

@ -2,12 +2,13 @@ import asyncio
import logging
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Sequence, Union
from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
ChatMessage,
HandoffMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
@ -32,6 +33,7 @@ from ...datamodel import (
TeamResult,
)
from ...teammanager import TeamManager
from .run_context import RunContext
logger = logging.getLogger(__name__)
@ -79,86 +81,90 @@ class WebSocketManager:
logger.error(f"Connection error for run {run_id}: {e}")
return False
async def start_stream(self, run_id: int, task: str, team_config: dict) -> None:
async def start_stream(
self, run_id: int, task: str | ChatMessage | Sequence[ChatMessage] | None, team_config: Dict
) -> None:
"""Start streaming task execution with proper run management"""
if run_id not in self._connections or run_id in self._closed_connections:
raise ValueError(f"No active connection for run {run_id}")
team_manager = TeamManager()
cancellation_token = CancellationToken()
self._cancellation_tokens[run_id] = cancellation_token
final_result = None
with RunContext.populate_context(run_id=run_id):
team_manager = TeamManager()
cancellation_token = CancellationToken()
self._cancellation_tokens[run_id] = cancellation_token
final_result = None
try:
# Update run with task and status
run = await self._get_run(run_id)
# get user Settings
user_settings = await self._get_settings(run.user_id)
env_vars = SettingsConfig(**user_settings.config).environment if user_settings else None
if run:
run.task = MessageConfig(content=task, source="user").model_dump()
run.status = RunStatus.ACTIVE
self.db_manager.upsert(run)
try:
# Update run with task and status
run = await self._get_run(run_id)
input_func = self.create_input_func(run_id)
if run is not None and run.user_id:
# get user Settings
user_settings = await self._get_settings(run.user_id)
env_vars = SettingsConfig(**user_settings.config).environment if user_settings else None # type: ignore
run.task = self._convert_images_in_dict(MessageConfig(content=task, source="user").model_dump())
run.status = RunStatus.ACTIVE
self.db_manager.upsert(run)
async for message in team_manager.run_stream(
task=task,
team_config=team_config,
input_func=input_func,
cancellation_token=cancellation_token,
env_vars=env_vars,
):
if cancellation_token.is_cancelled() or run_id in self._closed_connections:
logger.info(f"Stream cancelled or connection closed for run {run_id}")
break
input_func = self.create_input_func(run_id)
formatted_message = self._format_message(message)
if formatted_message:
await self._send_message(run_id, formatted_message)
async for message in team_manager.run_stream(
task=task,
team_config=team_config,
input_func=input_func,
cancellation_token=cancellation_token,
env_vars=env_vars,
):
if cancellation_token.is_cancelled() or run_id in self._closed_connections:
logger.info(f"Stream cancelled or connection closed for run {run_id}")
break
# Save messages by concrete type
if isinstance(
message,
(
TextMessage,
MultiModalMessage,
StopMessage,
HandoffMessage,
ToolCallRequestEvent,
ToolCallExecutionEvent,
LLMCallEventMessage,
),
):
await self._save_message(run_id, message)
# Capture final result if it's a TeamResult
elif isinstance(message, TeamResult):
final_result = message.model_dump()
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
if final_result:
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
formatted_message = self._format_message(message)
if formatted_message:
await self._send_message(run_id, formatted_message)
# Save messages by concrete type
if isinstance(
message,
(
TextMessage,
MultiModalMessage,
StopMessage,
HandoffMessage,
ToolCallRequestEvent,
ToolCallExecutionEvent,
LLMCallEventMessage,
),
):
await self._save_message(run_id, message)
# Capture final result if it's a TeamResult
elif isinstance(message, TeamResult):
final_result = message.model_dump()
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
if final_result:
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
else:
logger.warning(f"No final result captured for completed run {run_id}")
await self._update_run_status(run_id, RunStatus.COMPLETE)
else:
logger.warning(f"No final result captured for completed run {run_id}")
await self._update_run_status(run_id, RunStatus.COMPLETE)
else:
await self._send_message(
run_id,
{
"type": "completion",
"status": "cancelled",
"data": self._cancel_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# Update run with cancellation result
await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message)
await self._send_message(
run_id,
{
"type": "completion",
"status": "cancelled",
"data": self._cancel_message,
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
# Update run with cancellation result
await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message)
except Exception as e:
logger.error(f"Stream error for run {run_id}: {e}")
traceback.print_exc()
await self._handle_stream_error(run_id, e)
finally:
self._cancellation_tokens.pop(run_id, None)
except Exception as e:
logger.error(f"Stream error for run {run_id}: {e}")
traceback.print_exc()
await self._handle_stream_error(run_id, e)
finally:
self._cancellation_tokens.pop(run_id, None)
async def _save_message(
self, run_id: int, message: Union[BaseAgentEvent | BaseChatMessage, BaseChatMessage]
@ -170,7 +176,7 @@ class WebSocketManager:
db_message = Message(
session_id=run.session_id,
run_id=run_id,
config=message.model_dump(),
config=self._convert_images_in_dict(message.model_dump()),
user_id=None, # You might want to pass this from somewhere
)
self.db_manager.upsert(db_message)
@ -183,7 +189,7 @@ class WebSocketManager:
if run:
run.status = status
if team_result:
run.team_result = team_result
run.team_result = self._convert_images_in_dict(team_result)
if error:
run.error_message = error
self.db_manager.upsert(run)
@ -269,6 +275,18 @@ class WebSocketManager:
self._cancellation_tokens.pop(run_id, None)
self._input_responses.pop(run_id, None)
def _convert_images_in_dict(self, obj: Any) -> Any:
    """Return ``obj`` with every ``AGImage`` swapped for a JSON-friendly dict.

    Dicts and lists are rebuilt recursively; any other value is returned
    unchanged.
    """
    convert = self._convert_images_in_dict
    if isinstance(obj, dict):
        return {key: convert(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert(element) for element in obj]
    if isinstance(obj, AGImage):
        # Embed the image as a base64 data URL (always labelled image/png).
        encoded = obj.to_base64()
        return {"type": "image", "url": f"data:image/png;base64,{encoded}", "alt": "Image"}
    return obj
async def _send_message(self, run_id: int, message: dict) -> None:
"""Send a message through the WebSocket with connection state checking
@ -283,7 +301,7 @@ class WebSocketManager:
try:
if run_id in self._connections:
websocket = self._connections[run_id]
await websocket.send_json(message)
await websocket.send_json(self._convert_images_in_dict(message))
except WebSocketDisconnect:
logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
await self.disconnect(run_id)
@ -330,13 +348,20 @@ class WebSocketManager:
try:
if isinstance(message, MultiModalMessage):
message_dump = message.model_dump()
message_dump["content"] = [
message_dump["content"][0],
{
"url": f"data:image/png;base64,{message_dump['content'][1]['data']}",
"alt": "WebSurfer Screenshot",
},
]
message_content = []
for row in message_dump["content"]:
if isinstance(row, dict) and "data" in row:
message_content.append(
{
"url": f"data:image/png;base64,{row['data']}",
"alt": "WebSurfer Screenshot",
}
)
else:
message_content.append(row)
message_dump["content"] = message_content
return {"type": "message", "data": message_dump}
elif isinstance(message, TeamResult):
@ -365,6 +390,7 @@ class WebSocketManager:
except Exception as e:
logger.error(f"Message formatting error: {e}")
traceback.print_exc()
return None
async def _get_run(self, run_id: int) -> Optional[Run]:

View File

@ -0,0 +1,23 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generator
class RunContext:
    """Expose the id of the run currently being processed via a context variable.

    ``populate_context`` binds a run id for the duration of a ``with`` block
    and ``current_run_id`` reads it back from anywhere inside that block.
    """

    # No default is configured, so reading the variable outside
    # ``populate_context`` raises ``LookupError``.
    RUN_CONTEXT_VAR: ClassVar[ContextVar] = ContextVar("RUN_CONTEXT_VAR")

    @classmethod
    @contextmanager
    def populate_context(cls, run_id) -> Generator[None, Any, None]:
        """Bind ``run_id`` as the current run id for the enclosed block.

        The previous value (or the unset state) is restored on exit, even if
        the block raises.
        """
        # Go through ``cls`` exactly as ``current_run_id`` does, so both
        # methods always operate on the same variable (including in a
        # subclass that overrides ``RUN_CONTEXT_VAR``).
        token = cls.RUN_CONTEXT_VAR.set(run_id)
        try:
            yield
        finally:
            cls.RUN_CONTEXT_VAR.reset(token)

    @classmethod
    def current_run_id(cls) -> str:
        """Return the run id bound by the innermost ``populate_context``.

        NOTE(review): the return annotation says ``str``, but the visible
        caller (``start_stream``) binds an ``int`` run id; confirm the
        intended type.

        Raises:
            RuntimeError: if no run id is bound in the current context.
        """
        try:
            return cls.RUN_CONTEXT_VAR.get()
        except LookupError as e:
            raise RuntimeError("Error getting run id") from e

View File

@ -1,10 +1,11 @@
# api/routes/sessions.py
import re
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from ...datamodel import Message, Run, Session
from ...datamodel import Message, Response, Run, Session
from ..deps import get_db
router = APIRouter()
@ -27,12 +28,16 @@ async def get_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict
@router.post("/")
async def create_session(session: Session, db=Depends(get_db)) -> Dict:
async def create_session(session: Session, db=Depends(get_db)) -> Response:
"""Create a new session"""
response = db.upsert(session)
if not response.status:
raise HTTPException(status_code=400, detail=response.message)
return {"status": True, "data": response.data}
try:
response = db.upsert(session)
if not response.status:
return Response(status=False, message=f"Failed to create session: {response.message}")
return Response(status=True, data=response.data, message="Session created successfully")
except Exception as e:
logger.error(f"Error creating session: {str(e)}")
return Response(status=False, message=f"Failed to create session: {str(e)}")
@router.put("/{session_id}")

View File

@ -14,6 +14,7 @@ router = APIRouter()
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
"""List all teams for a user"""
response = db.get(Team, filters={"user_id": user_id})
if not response.data or len(response.data) == 0:
default_gallery = create_default_gallery()
default_team = Team(user_id=user_id, component=default_gallery.components.teams[0].model_dump())

View File

@ -8,10 +8,11 @@ from fastapi.websockets import WebSocketState
from loguru import logger
from ...datamodel import Run, RunStatus
from ...utils.utils import construct_task
from ..auth.dependencies import get_ws_auth_manager
from ..auth.wsauth import WebSocketAuthHandler
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager
from ..managers.connection import WebSocketManager
router = APIRouter()
@ -26,21 +27,6 @@ async def run_websocket(
):
"""WebSocket endpoint for run communication"""
async def start_stream_wrapper(run_id, task, team_config):
try:
await ws_manager.start_stream(run_id, task, team_config)
except Exception as e:
logger.error(f"Error in start_stream for run {run_id}: {str(e)}")
# Optionally notify the client about the error
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json(
{
"type": "error",
"error": f"Stream processing error: {str(e)}",
"timestamp": datetime.utcnow().isoformat(),
}
)
try:
# Verify run exists before connecting
run_response = db.get(Run, filters={"id": run_id}, return_json=False)
@ -98,11 +84,12 @@ async def run_websocket(
if message.get("type") == "start":
# Handle start message
logger.info(f"Received start request for run {run_id}")
task = message.get("task")
task = construct_task(query=message.get("task"), files=message.get("files"))
team_config = message.get("team_config")
if task and team_config:
# Start the stream in a separate task
asyncio.create_task(start_stream_wrapper(run_id, task, team_config))
asyncio.create_task(ws_manager.start_stream(run_id, task, team_config))
else:
logger.warning(f"Invalid start message format for run {run_id}")
await websocket.send_json(

View File

@ -11,8 +11,12 @@ import {
PanelLeftOpen,
GalleryHorizontalEnd,
Rocket,
Beaker,
LucideBeaker,
FlaskConical,
} from "lucide-react";
import Icon from "./icons";
import { BeakerIcon } from "@heroicons/react/24/outline";
interface INavItem {
name: string;
@ -44,6 +48,12 @@ const navigation: INavItem[] = [
icon: GalleryHorizontalEnd,
breadcrumbs: [{ name: "Gallery", href: "/gallery", current: true }],
},
{
name: "Labs",
href: "/labs",
icon: FlaskConical,
breadcrumbs: [{ name: "Labs", href: "/labs", current: true }],
},
{
name: "Deploy",
href: "/deploy",

View File

@ -42,6 +42,7 @@ export interface FunctionExecutionResult {
export interface BaseMessageConfig {
source: string;
models_usage?: RequestUsage;
metadata?: Record<string, string>;
}
export interface TextMessageConfig extends BaseMessageConfig {
@ -373,7 +374,7 @@ export interface Run {
created_at: string;
updated_at?: string;
status: RunStatus;
task: AgentMessageConfig;
task: AgentMessageConfig[];
team_result: TeamResult | null;
messages: Message[];
error_message?: string;

View File

@ -1,3 +1,4 @@
import { RcFile } from "antd/es/upload";
import { IStatus } from "../types/app";
export const getServerUrl = () => {
@ -116,3 +117,24 @@ export const fetchVersion = () => {
return null;
});
};
/**
 * Read every file and resolve with its name, MIME type and base64 payload.
 *
 * Each file is read with `FileReader.readAsDataURL`; the `data:...;base64,`
 * prefix is removed so `content` holds only the base64 characters. The
 * returned promise rejects as soon as any single read fails.
 *
 * @param files Files to encode (defaults to an empty list).
 * @returns One `{ name, content, type }` record per input file, in order.
 */
export const convertFilesToBase64 = async (files: RcFile[] = []) => {
  return Promise.all(
    files.map(
      (file) =>
        new Promise<{ name: string; content: string; type: string }>(
          (resolve, reject) => {
            const reader = new FileReader();
            reader.onload = () => {
              const dataUrl =
                typeof reader.result === "string" ? reader.result : "";
              // Keep only what follows the first comma of the data URL.
              // An empty payload (e.g. an empty file) must stay empty: the
              // previous `split(",")[1] || dataUrl` fallback returned the
              // data URL prefix itself as the file content in that case.
              const commaIndex = dataUrl.indexOf(",");
              let base64Data: string;
              if (commaIndex >= 0) {
                base64Data = dataUrl.slice(commaIndex + 1);
              } else if (dataUrl.startsWith("data:")) {
                base64Data = "";
              } else {
                base64Data = dataUrl;
              }
              resolve({ name: file.name, content: base64Data, type: file.type });
            };
            reader.onerror = reject;
            reader.readAsDataURL(file);
          }
        )
    )
  );
};

View File

@ -158,11 +158,11 @@ export const TruncatableText = memo(
{isFullscreen && (
<div
className="fixed inset-0 bg-black/80 z-50 flex items-center justify-center"
className="fixed inset-0 dark:bg-black/80 bg-black/10 z-50 flex items-center justify-center"
onClick={() => setIsFullscreen(false)}
>
<div
className="relative bg-secondary scroll w-full h-full md:w-4/5 md:h-4/5 md:rounded-lg p-8 overflow-auto"
className="relative bg-primary scroll w-full h-full md:w-4/5 md:h-4/5 md:rounded-lg p-8 overflow-auto"
style={{ opacity: 0.95 }}
onClick={(e) => e.stopPropagation()}
>

View File

@ -1,4 +1,4 @@
import React, { useState } from "react";
import React, { useState, useEffect } from "react";
import { Tabs, Button, Tooltip, Drawer, Input } from "antd";
import {
Package,
@ -12,6 +12,7 @@ import {
Copy,
Trash,
Plus,
Download,
} from "lucide-react";
import { ComponentEditor } from "../teambuilder/builder/component-editor/component-editor";
import { TruncatableText } from "../atoms";
@ -160,6 +161,13 @@ export const GalleryDetail: React.FC<{
gallery.config.metadata.description
);
useEffect(() => {
setTempName(gallery.config.name);
setTempDescription(gallery.config.metadata.description);
setActiveTab("team");
setEditingComponent(null);
}, [gallery.id]);
const updateGallery = (
category: CategoryKey,
updater: (
@ -286,6 +294,21 @@ export const GalleryDetail: React.FC<{
setIsEditingDetails(false);
};
const handleDownload = () => {
const dataStr = JSON.stringify(gallery, null, 2);
const dataBlob = new Blob([dataStr], { type: "application/json" });
const url = URL.createObjectURL(dataBlob);
const link = document.createElement("a");
link.href = url;
link.download = `${gallery.config.name
.toLowerCase()
.replace(/\s+/g, "_")}.json`;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
};
const tabItems = Object.entries(iconMap).map(([key, Icon]) => ({
key,
label: (
@ -355,25 +378,6 @@ export const GalleryDetail: React.FC<{
</Tooltip>
)}
</div>
{!isEditingDetails ? (
<Button
icon={<Edit className="w-4 h-4" />}
onClick={() => setIsEditingDetails(true)}
type="text"
className="text-white hover:text-white/80"
>
Edit
</Button>
) : (
<div className="flex gap-2">
<Button onClick={() => setIsEditingDetails(false)}>
Cancel
</Button>
<Button type="primary" onClick={handleDetailsSave}>
Save
</Button>
</div>
)}
</div>
{isEditingDetails ? (
<TextArea
@ -383,9 +387,39 @@ export const GalleryDetail: React.FC<{
rows={2}
/>
) : (
<p className="text-secondary w-1/2 mt-2 line-clamp-2">
{gallery.config.metadata.description}
</p>
<div className="flex flex-col gap-2">
<p className="text-secondary w-1/2 mt-2 line-clamp-2">
{gallery.config.metadata.description}
</p>
<div className="flex gap-0">
<Tooltip title="Edit Gallery">
<Button
icon={<Edit className="w-4 h-4" />}
onClick={() => setIsEditingDetails(true)}
type="text"
className="text-white hover:text-white/80"
/>
</Tooltip>
<Tooltip title="Download Gallery">
<Button
icon={<Download className="w-4 h-4" />}
onClick={handleDownload}
type="text"
className="text-white hover:text-white/80"
/>
</Tooltip>
</div>
</div>
)}
{isEditingDetails && (
<div className="flex gap-2 mt-2">
<Button onClick={() => setIsEditingDetails(false)}>
Cancel
</Button>
<Button type="primary" onClick={handleDetailsSave}>
Save
</Button>
</div>
)}
</div>
<div className="flex gap-2">

View File

@ -0,0 +1,27 @@
import React from "react";
import { Alert } from "antd";
import { copyToClipboard } from "./guides";
import { Download } from "lucide-react";
// Placeholder lab page: a heading followed by an informational
// "Prerequisites" alert.
const ComponentLab: React.FC = () => {
  return (
    <div className="">
      {/* "text-2xl" was misspelled as "tdext-2xl", so the heading size was never applied. */}
      <h1 className="text-2xl font-bold mb-6">
        Using AutoGen Studio Teams in Python Code and REST API
      </h1>

      <Alert
        className="mb-6"
        message="Prerequisites"
        description={
          <ul className="list-disc pl-4 mt-2 space-y-1">
            <li>AutoGen Studio installed</li>
          </ul>
        }
        type="info"
      />
    </div>
  );
};
export default ComponentLab;

View File

@ -0,0 +1,27 @@
import React from "react";
import { Lab } from "../types";
import ComponentLab from "./component";
interface LabContentProps {
lab: Lab;
}
// Write `text` to the system clipboard through the Clipboard API.
// The promise returned by `writeText` is not awaited or handled here, so a
// failed write is not reported to the caller.
export const copyToClipboard = (text: string) => {
  navigator.clipboard.writeText(text);
};
// Pick the body to render for the selected lab, keyed on the lab id.
// NOTE(review): the only entry in `defaultLabs` visible in this change uses
// the id "component-builder"; confirm that "python-setup" is the intended id,
// otherwise <ComponentLab /> is never shown.
export const LabContent: React.FC<LabContentProps> = ({ lab }) => {
  if (lab.id === "python-setup") {
    return <ComponentLab />;
  }

  // Every other lab falls back to a work-in-progress notice.
  return (
    <div className="text-secondary">
      A Lab with the title <strong>{lab.title}</strong> is work in progress!
    </div>
  );
};
export default LabContent;

View File

@ -0,0 +1,87 @@
import React, { useState, useEffect } from "react";
import { ChevronRight, TriangleAlert } from "lucide-react";
import { LabsSidebar } from "./sidebar";
import { Lab, defaultLabs } from "./types";
import { LabContent } from "./labs/guides";
export const LabsManager: React.FC = () => {
const [isLoading, setIsLoading] = useState(false);
const [labs, setLabs] = useState<Lab[]>([]);
const [currentLab, setcurrentLab] = useState<Lab | null>(null);
const [isSidebarOpen, setIsSidebarOpen] = useState(() => {
if (typeof window !== "undefined") {
const stored = localStorage.getItem("labsSidebar");
return stored !== null ? JSON.parse(stored) : true;
}
return true;
});
// Persist sidebar state
useEffect(() => {
if (typeof window !== "undefined") {
localStorage.setItem("labsSidebar", JSON.stringify(isSidebarOpen));
}
}, [isSidebarOpen]);
// Set first guide as current if none selected
useEffect(() => {
if (!currentLab && labs.length > 0) {
setcurrentLab(labs[0]);
}
}, [labs, currentLab]);
return (
<div className="relative flex h-full w-full">
{/* Sidebar */}
<div
className={`absolute left-0 top-0 h-full transition-all duration-200 ease-in-out ${
isSidebarOpen ? "w-64" : "w-12"
}`}
>
<LabsSidebar
isOpen={isSidebarOpen}
labs={labs}
currentLab={currentLab}
onToggle={() => setIsSidebarOpen(!isSidebarOpen)}
onSelectLab={setcurrentLab}
isLoading={isLoading}
/>
</div>
{/* Main Content */}
<div
className={`flex-1 transition-all max-w-5xl -mr-6 duration-200 ${
isSidebarOpen ? "ml-64" : "ml-12"
}`}
>
<div className="p-4 pt-2">
{/* Breadcrumb */}
<div className="flex items-center gap-2 mb-4 text-sm">
<span className="text-primary font-medium">Labs</span>
{currentLab && (
<>
<ChevronRight className="w-4 h-4 text-secondary" />
<span className="text-secondary">{currentLab.title}</span>
</>
)}
</div>
<div className="rounded border border-secondary border-dashed p-2 text-sm mb-4">
<TriangleAlert className="w-4 h-4 inline-block mr-2 -mt-1 text-secondary " />{" "}
Labs is designed to host experimental features for building and
debugging multiagent applications.
</div>
{/* Content Area */}
{currentLab ? (
<LabContent lab={currentLab} />
) : (
<div className="flex items-center justify-center h-[calc(100vh-190px)] text-secondary">
Select a lab from the sidebar to get started
</div>
)}
</div>
</div>
</div>
);
};
export default LabsManager;

View File

@ -0,0 +1,111 @@
import React from "react";
import { Button, Tooltip } from "antd";
import {
PanelLeftClose,
PanelLeftOpen,
Book,
InfoIcon,
RefreshCcw,
} from "lucide-react";
import type { Lab } from "./types";
interface LabsSidebarProps {
isOpen: boolean;
labs: Lab[];
currentLab: Lab | null;
onToggle: () => void;
onSelectLab: (guide: Lab) => void;
isLoading?: boolean;
}
export const LabsSidebar: React.FC<LabsSidebarProps> = ({
isOpen,
labs,
currentLab,
onToggle,
onSelectLab,
isLoading = false,
}) => {
// Render collapsed state
if (!isOpen) {
return (
<div className="h-full border-r border-secondary">
<div className="p-2 -ml-2">
<Tooltip title="Documentation">
<button
onClick={onToggle}
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
>
<PanelLeftOpen strokeWidth={1.5} className="h-6 w-6" />
</button>
</Tooltip>
</div>
</div>
);
}
return (
<div className="h-full border-r border-secondary">
{/* Header */}
<div className="flex items-center justify-between pt-0 p-4 pl-2 pr-2 border-b border-secondary">
<div className="flex items-center gap-2">
{/* <Book className="w-4 h-4" /> */}
<span className="text-primary font-medium">Labs</span>
{/* <span className="px-2 py-0.5 text-xs bg-accent/10 text-accent rounded">
{guides.length}
</span> */}
</div>
<Tooltip title="Close Sidebar">
<button
onClick={onToggle}
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
>
<PanelLeftClose strokeWidth={1.5} className="h-6 w-6" />
</button>
</Tooltip>
</div>
{/* Loading State */}
{isLoading && (
<div className="p-4">
<RefreshCcw className="w-4 h-4 inline-block animate-spin" />
</div>
)}
{/* Empty State */}
{!isLoading && labs.length === 0 && (
<div className="p-2 mt-2 mr-2 text-center text-secondary text-sm border border-dashed rounded">
<InfoIcon className="w-4 h-4 inline-block mr-1.5 -mt-0.5" />
No labs available. Please check back later.
</div>
)}
{/* Guides List */}
<div className="overflow-y-auto h-[calc(100%-64px)] mt-4">
{labs.map((lab) => (
<div key={lab.id} className="relative">
<div
className={`absolute top-1 left-0.5 z-50 h-[calc(100%-8px)]
w-1 bg-opacity-80 rounded ${
currentLab?.id === lab.id ? "bg-accent" : "bg-tertiary"
}`}
/>
<div
className={`group ml-1 flex flex-col p-2 rounded-l cursor-pointer hover:bg-secondary ${
currentLab?.id === lab.id
? "border-accent bg-secondary"
: "border-transparent"
}`}
onClick={() => onSelectLab(lab)}
>
{/* Guide Title */}
<div className="flex items-center justify-between">
<span className="text-sm truncate">{lab.title}</span>
</div>
</div>
</div>
))}
</div>
</div>
);
};

View File

@ -0,0 +1,13 @@
// Describes one entry shown in the Labs sidebar.
export interface Lab {
  // Identifier used as the React list key and to select the lab's content.
  id: string;
  // Label displayed in the sidebar and the breadcrumb.
  title: string;
  // NOTE(review): these categories look carried over from another feature; confirm they apply to labs.
  type: "python" | "docker" | "cloud";
}
// Built-in list of labs; currently a single "Component Builder" entry.
export const defaultLabs: Lab[] = [
  {
    id: "component-builder",
    title: "Component Builder",
    type: "python",
  },
];

View File

@ -1,6 +1,6 @@
import * as React from "react";
import { Button, message, Tooltip } from "antd";
import { getServerUrl } from "../../../utils/utils";
import { convertFilesToBase64, getServerUrl } from "../../../utils/utils";
import { IStatus } from "../../../types/app";
import {
Run,
@ -27,6 +27,7 @@ import {
X,
} from "lucide-react";
import SessionDropdown from "./sessiondropdown";
import { RcFile } from "antd/es/upload";
const logo = require("../../../../images/landing/welcome.svg").default;
interface ChatViewProps {
@ -395,7 +396,7 @@ export default function ChatView({
}
};
const runTask = async (query: string) => {
const runTask = async (query: string, files: RcFile[] = []) => {
setError(null);
setLoading(true);
@ -405,13 +406,13 @@ export default function ChatView({
setActiveSocket(null);
activeSocketRef.current = null;
}
if (inputTimeoutRef.current) {
clearTimeout(inputTimeoutRef.current);
inputTimeoutRef.current = null;
}
if (!session?.id || !teamConfig) {
// Add teamConfig check
setLoading(false);
return;
}
@ -419,6 +420,9 @@ export default function ChatView({
try {
const runId = await createRun(session.id);
// Process files using the extracted function
const processedFiles = await convertFilesToBase64(files);
// Initialize run state BEFORE websocket connection
setCurrentRun({
id: runId,
@ -433,8 +437,8 @@ export default function ChatView({
error_message: undefined,
});
// Setup WebSocket
const socket = setupWebSocket(runId, query);
// Setup WebSocket with files
const socket = setupWebSocket(runId, query, processedFiles);
setActiveSocket(socket);
activeSocketRef.current = socket;
} catch (error) {
@ -444,7 +448,11 @@ export default function ChatView({
}
};
const setupWebSocket = (runId: number, query: string): WebSocket => {
const setupWebSocket = (
runId: number,
query: string,
files: { name: string; type: string; content: string }[]
): WebSocket => {
if (!session || !session.id) {
throw new Error("Invalid session configuration");
}
@ -465,6 +473,7 @@ export default function ChatView({
id: runId,
created_at: new Date().toISOString(),
status: "active",
task: createMessage(
{ content: query, source: "user" },
runId,
@ -481,6 +490,7 @@ export default function ChatView({
JSON.stringify({
type: "start",
task: query,
files: files,
team_config: teamConfig,
})
);
@ -657,7 +667,10 @@ export default function ChatView({
onSubmit={runTask}
loading={loading}
error={error}
disabled={currentRun?.status === "awaiting_input"}
disabled={
currentRun?.status === "awaiting_input" ||
currentRun?.status === "active"
}
/>
</div>
)}

View File

@ -1,5 +1,3 @@
"use client";
import {
PaperAirplaneIcon,
Cog6ToothIcon,
@ -7,9 +5,33 @@ import {
} from "@heroicons/react/24/outline";
import * as React from "react";
import { IStatus } from "../../../types/app";
import { Upload, message, Button, Tooltip, notification } from "antd";
import type { UploadFile, UploadProps, RcFile } from "antd/es/upload/interface";
import {
FileTextIcon,
ImageIcon,
Paperclip,
UploadIcon,
XIcon,
} from "lucide-react";
import { truncateText } from "../../../utils/utils";
// Maximum file size in bytes (5MB)
const MAX_FILE_SIZE = 5 * 1024 * 1024;
// Allowed file types
const ALLOWED_FILE_TYPES = [
"text/plain",
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
];
// Threshold for large text files (in characters)
const LARGE_TEXT_THRESHOLD = 1500;
interface ChatInputProps {
onSubmit: (text: string) => void;
onSubmit: (text: string, files: RcFile[]) => void;
loading: boolean;
error: IStatus | null;
disabled?: boolean;
@ -23,7 +45,11 @@ export default function ChatInput({
}: ChatInputProps) {
const textAreaRef = React.useRef<HTMLTextAreaElement>(null);
const [previousLoading, setPreviousLoading] = React.useState(loading);
const [text, setText] = React.useState("");
const [text, setText] = React.useState("What is the capital of France?");
const [fileList, setFileList] = React.useState<UploadFile[]>([]);
const [dragOver, setDragOver] = React.useState(false);
const [notificationApi, notificationContextHolder] =
notification.useNotification();
const textAreaDefaultHeight = "64px";
const isInputDisabled = disabled || loading;
@ -31,7 +57,9 @@ export default function ChatInput({
// Handle textarea auto-resize
React.useEffect(() => {
if (textAreaRef.current) {
textAreaRef.current.style.height = textAreaDefaultHeight;
// Temporarily set height to auto to get proper scrollHeight
textAreaRef.current.style.height = "auto";
// Then set to the scroll height
const scrollHeight = textAreaRef.current.scrollHeight;
textAreaRef.current.style.height = `${scrollHeight}px`;
}
@ -45,11 +73,139 @@ export default function ChatInput({
setPreviousLoading(loading);
}, [loading, error, previousLoading]);
// Add paste event listener
React.useEffect(() => {
const handlePaste = (e: ClipboardEvent) => {
if (isInputDisabled) return;
// Handle image paste
if (e.clipboardData?.items) {
let hasImageItem = false;
for (let i = 0; i < e.clipboardData.items.length; i++) {
const item = e.clipboardData.items[i];
// Handle image items
if (item.type.indexOf("image/") === 0) {
hasImageItem = true;
const file = item.getAsFile();
if (file && file.size <= MAX_FILE_SIZE) {
// Prevent the default paste behavior for images
e.preventDefault();
// Create a unique file name
const fileName = `pasted-image-${new Date().getTime()}.png`;
// Create a new File with a proper name
const namedFile = new File([file], fileName, { type: file.type });
// Convert to the expected UploadFile format
const uploadFile: UploadFile = {
uid: `paste-${Date.now()}`,
name: fileName,
status: "done",
size: namedFile.size,
type: namedFile.type,
originFileObj: namedFile as RcFile,
};
// Add to file list
setFileList((prev) => [...prev, uploadFile]);
// Show successful paste notification
message.success(`Image pasted successfully`);
} else if (file && file.size > MAX_FILE_SIZE) {
message.error(`Pasted image is too large. Maximum size is 5MB.`);
}
}
// Handle text items - only if there's a large amount of text
if (item.type === "text/plain" && !hasImageItem) {
item.getAsString((text) => {
// Only process for large text
if (text.length > LARGE_TEXT_THRESHOLD) {
// We need to prevent the default paste behavior
// But since we're in an async callback, we need to
// manually clear the textarea's selection value
setTimeout(() => {
if (textAreaRef.current) {
const currentValue = textAreaRef.current.value;
const selectionStart =
textAreaRef.current.selectionStart || 0;
const selectionEnd = textAreaRef.current.selectionEnd || 0;
// Remove the pasted text from the textarea
const newValue =
currentValue.substring(0, selectionStart - text.length) +
currentValue.substring(selectionEnd);
// Update the textarea
textAreaRef.current.value = newValue;
// Trigger the onChange event manually
setText(newValue);
}
}, 0);
// Prevent default paste for large text
e.preventDefault();
// Create a text file from the pasted content
const blob = new Blob([text], { type: "text/plain" });
const file = new File(
[blob],
`pasted-text-${new Date().getTime()}.txt`,
{ type: "text/plain" }
);
// Add to file list
const uploadFile: UploadFile = {
uid: `paste-${Date.now()}`,
name: file.name,
status: "done",
size: file.size,
type: file.type,
originFileObj: file as RcFile,
};
setFileList((prev) => [...prev, uploadFile]);
// Notify user about the conversion
notificationApi.info({
message: (
<span className="text-sm">
Large Text Converted to File
</span>
),
description: (
<span className="text-sm text-secondary">
Your pasted text has been attached as a file.
</span>
),
duration: 3,
});
}
});
}
}
}
};
// Add the paste event listener to the document
document.addEventListener("paste", handlePaste);
// Cleanup
return () => {
document.removeEventListener("paste", handlePaste);
};
}, [isInputDisabled, notificationApi]);
const resetInput = () => {
if (textAreaRef.current) {
textAreaRef.current.value = "";
textAreaRef.current.style.height = textAreaDefaultHeight;
setText("");
setFileList([]);
}
};
@ -58,9 +214,18 @@ export default function ChatInput({
};
const handleSubmit = () => {
if (textAreaRef.current?.value && !isInputDisabled) {
const query = textAreaRef.current.value;
onSubmit(query);
if (
(textAreaRef.current?.value || fileList.length > 0) &&
!isInputDisabled
) {
const query = textAreaRef.current?.value || "";
// Get all valid RcFile objects
const files = fileList
.filter((file) => file.originFileObj)
.map((file) => file.originFileObj as RcFile);
onSubmit(query, files);
}
};
@ -71,12 +236,161 @@ export default function ChatInput({
}
};
// Props for the antd <Upload> trigger. Files are validated and staged in
// local state (fileList) only; nothing is uploaded to a server from here.
const uploadProps: UploadProps = {
  name: "file",
  multiple: true,
  fileList,
  beforeUpload: (file) => {
    // Reject files above the size limit
    if (file.size > MAX_FILE_SIZE) {
      message.error(`${file.name} is too large. Maximum size is 5MB.`);
      return Upload.LIST_IGNORE;
    }

    // Reject MIME types that are not in the allow-list
    if (!ALLOWED_FILE_TYPES.includes(file.type)) {
      notificationApi.warning({
        message: <span className="text-sm">Unsupported File Type</span>,
        description: (
          <span className="text-sm text-secondary">
            Please upload only text (.txt) or images (.jpg, .png, .gif, .svg)
            files.
          </span>
        ),
        showProgress: true,
        duration: 8.5,
      });
      return Upload.LIST_IGNORE;
    }

    // Wrap the raw file so that originFileObj is populated for handleSubmit
    const uploadFile: UploadFile = {
      uid: file.uid,
      name: file.name,
      status: "done",
      size: file.size,
      type: file.type,
      originFileObj: file,
    };
    setFileList((prev) => [...prev, uploadFile]);
    return false; // Prevent automatic upload
  },
  onRemove: (file) => {
    // Functional update (like every other setFileList call here) so a removal
    // never overwrites additions queued from the render-time fileList snapshot.
    setFileList((prev) => prev.filter((item) => item.uid !== file.uid));
  },
  showUploadList: false, // We'll handle our own custom file preview
  customRequest: ({ onSuccess }) => {
    // Mock successful upload since we're not actually uploading anywhere yet
    if (onSuccess) onSuccess("ok");
  },
};
/**
 * Returns the preview icon for an attached file: the image icon when its MIME
 * type starts with "image/", the text-file icon otherwise (including when the
 * type is missing).
 */
const getFileIcon = (file: UploadFile) => {
  const isImage = (file.type || "").startsWith("image/");
  return isImage ? (
    <ImageIcon className="w-4 h-4" />
  ) : (
    <FileTextIcon className="w-4 h-4" />
  );
};
// Drag-and-drop handlers for attaching files to the chat input.

/**
 * Highlights the drop zone while a drag hovers over it. The highlight is not
 * shown when the input is disabled.
 */
const handleDragOver = (e: React.DragEvent) => {
  e.preventDefault();
  e.stopPropagation();
  if (isInputDisabled) return;
  setDragOver(true);
};
/**
 * Clears the drop-zone highlight once the drag has left the container.
 */
const handleDragLeave = (e: React.DragEvent) => {
  e.preventDefault();
  e.stopPropagation();
  // dragleave also bubbles up when the pointer moves between child elements
  // (textarea, buttons). Keep the highlight while the element being entered
  // is still inside this container; a null relatedTarget falls through and
  // clears the highlight exactly as before.
  const enteredNode = e.relatedTarget as Node | null;
  if (enteredNode && e.currentTarget.contains(enteredNode)) return;
  setDragOver(false);
};
/**
 * Handles files dropped onto the chat input. Each file is checked against the
 * size limit and the allowed MIME types; accepted files are appended to the
 * attachment list, rejected ones trigger a message/notification.
 */
const handleDrop = (e: React.DragEvent) => {
  e.preventDefault();
  e.stopPropagation();
  setDragOver(false);
  if (isInputDisabled) return;

  const droppedFiles = Array.from(e.dataTransfer.files);
  droppedFiles.forEach((file, index) => {
    // Check file size
    if (file.size > MAX_FILE_SIZE) {
      message.error(`${file.name} is too large. Maximum size is 5MB.`);
      return;
    }

    // Check file type
    if (!ALLOWED_FILE_TYPES.includes(file.type)) {
      notificationApi.warning({
        message: <span className="text-sm">Unsupported File Type</span>,
        description: (
          <span className="text-sm text-secondary">
            Please upload only text (.txt) or images (.jpg, .png, .gif, .svg)
            files.
          </span>
        ),
        showProgress: true,
        duration: 8.5,
      });
      return;
    }

    // Add to file list
    const uploadFile: UploadFile = {
      // Date.now() can be the same for every file in one drop, so the index is
      // included too: two same-named files would otherwise share a uid, which
      // is used as the React key and by the remove button's filter.
      uid: `file-${Date.now()}-${index}-${file.name}`,
      name: file.name,
      status: "done",
      size: file.size,
      type: file.type,
      originFileObj: file as RcFile,
    };
    setFileList((prev) => [...prev, uploadFile]);
  });
};
return (
<div className="mt-2 w-full">
{notificationContextHolder}
{/* File previews: one removable chip per attached file */}
{fileList.length > 0 && (
  <div className="-mb-2 mx-1 bg-tertiary rounded-t border-b-0 p-2 flex border flex-wrap gap-2">
    {fileList.map((file) => (
      <div
        key={file.uid}
        className="flex items-center gap-1 bg-secondary rounded px-2 py-1 text-xs"
      >
        {getFileIcon(file)}
        <span className="truncate max-w-[150px]">
          {truncateText(file.name, 20)}
        </span>
        <Button
          type="text"
          size="small"
          className="p-0 ml-1 flex items-center justify-center"
          onClick={() =>
            setFileList((prev) => prev.filter((f) => f.uid !== file.uid))
          }
          icon={<XIcon className="w-3 h-3" />}
        />
      </div>
    ))}
  </div>
)}
<div
className={`mt-2 rounded shadow-sm flex mb-1 ${
isInputDisabled ? "opacity-50" : ""
}`}
className={`mt-2 rounded shadow-sm flex mb-1 transition-all duration-200 ${
dragOver ? "ring-2 ring-blue-400" : ""
} ${isInputDisabled ? "opacity-50" : ""}`}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDrop}
>
<form
className="flex-1 relative"
@ -85,38 +399,84 @@ export default function ChatInput({
handleSubmit();
}}
>
{/* Overlay shown while a drag hovers over the input; pointer-events-none
    keeps it from intercepting the drop itself */}
{dragOver && (
  <div className="absolute inset-0 bg-blue-100 bg-opacity-60 flex items-center justify-center rounded z-10 pointer-events-none">
    <div className="text-accent text-xs items-center">
      <UploadIcon className="h-4 w-4 mr-2 inline-block" />
      <span className="text-xs inline-block">Drop files here</span>
    </div>
  </div>
)}
<textarea
id="queryInput"
name="queryInput"
ref={textAreaRef}
defaultValue={"what is the height of the eiffel tower"}
value={text}
onChange={handleTextChange}
onKeyDown={handleKeyDown}
className={`flex items-center w-full resize-none text-gray-600 rounded border border-accent bg-white p-2 pl-5 pr-16 ${
isInputDisabled ? "cursor-not-allowed" : ""
}`}
className={`flex items-center w-full resize-none text-gray-600 rounded ${
dragOver
? "border-2 border-blue-500 bg-blue-50"
: "border border-accent bg-white"
} p-2 pl-5 pr-16 ${isInputDisabled ? "cursor-not-allowed" : ""}`}
style={{
maxHeight: "120px",
overflowY: "auto",
minHeight: "50px",
transition: "all 0.2s ease-in-out",
}}
placeholder="Type your message here..."
placeholder={
dragOver ? "Drop files here..." : "Type your message here..."
}
disabled={isInputDisabled}
/>
<button
type="button"
onClick={handleSubmit}
disabled={isInputDisabled}
className={`absolute right-3 bottom-2 bg-accent transition duration-300 rounded flex justify-center items-center w-11 h-9 ${
isInputDisabled ? "cursor-not-allowed" : "hover:brightness-75"
}`}
>
{loading ? (
<Cog6ToothIcon className="text-white animate-spin rounded-full h-6 w-6" />
) : (
<PaperAirplaneIcon className="h-6 w-6 text-white" />
)}
</button>
<div className={`absolute right-3 bottom-2 flex gap-2`}>
<div
className={` ${
disabled || isInputDisabled
? " opacity-50 pointer-events-none "
: ""
}`}
>
{" "}
<Upload className="zero-padding-upload " {...uploadProps}>
<Tooltip
title=<span className="text-sm">
Upload File{" "}
<span className="text-secondary text-xs">(max 5mb)</span>
</span>
placement="top"
>
<Button type="text" disabled={isInputDisabled} className=" ">
<UploadIcon
strokeWidth={2}
size={26}
className="p-1 inline-block w-8 text-accent"
/>
</Button>
</Tooltip>
</Upload>
</div>
<button
type="button"
onClick={handleSubmit}
disabled={
isInputDisabled || (text.trim() === "" && fileList.length === 0)
}
className={`bg-accent transition duration-300 rounded flex justify-center items-center w-11 h-9 ${
isInputDisabled || (text.trim() === "" && fileList.length === 0)
? "cursor-not-allowed opacity-50"
: "hover:brightness-75"
}`}
>
{loading ? (
<Cog6ToothIcon className="text-white animate-spin rounded-full h-6 w-6" />
) : (
<PaperAirplaneIcon className="h-6 w-6 text-white" />
)}
</button>
</div>
</form>
</div>

View File

@ -25,9 +25,10 @@ const getImageSource = (item: ImageContent): string => {
return "/api/placeholder/400/320";
};
const RenderMultiModal: React.FC<{ content: (string | ImageContent)[] }> = ({
content,
}) => (
const RenderMultiModal: React.FC<{
content: (string | ImageContent)[];
thumbnail?: boolean;
}> = ({ content, thumbnail = false }) => (
<div className="space-y-2">
{content.map((item, index) =>
typeof item === "string" ? (
@ -37,7 +38,9 @@ const RenderMultiModal: React.FC<{ content: (string | ImageContent)[] }> = ({
key={index}
src={getImageSource(item)}
alt={item.alt || "Image"}
className="w-full h-auto rounded border border-secondary"
className={` h-auto rounded border border-secondary ${
thumbnail ? "w-24 h-24 " : " w-full "
}`}
/>
)
)}
@ -101,6 +104,18 @@ export const messageUtils = {
);
},
// Type guard: true when `content` is an array in which every element is a
// non-null object exposing `source`, `content` and `type` keys. An empty
// array also passes, since `every` is vacuously true.
isNestedMessageContent(content: unknown): content is AgentMessageConfig[] {
  if (!Array.isArray(content)) return false;

  const requiredKeys = ["source", "content", "type"];
  return content.every(
    (entry) =>
      entry !== null &&
      typeof entry === "object" &&
      requiredKeys.every((key) => key in entry)
  );
},
isMultiModalContent(content: unknown): content is (string | ImageContent)[] {
if (!Array.isArray(content)) return false;
return content.every(
@ -128,20 +143,66 @@ export const messageUtils = {
// Returns true when the message source is exactly the string "user".
isUser(source: string): boolean {
return source === "user";
},
// Type guard distinguishing a list of messages from a single message;
// it only checks Array.isArray and does not inspect the elements.
isMessageArray(
message: AgentMessageConfig | AgentMessageConfig[]
): message is AgentMessageConfig[] {
return Array.isArray(message);
},
};
interface MessageProps {
message: AgentMessageConfig;
message: AgentMessageConfig | AgentMessageConfig[];
isLast?: boolean;
className?: string;
}
/**
 * Renders a list of nested agent messages, one below the other.
 *
 * The first item is shown at base text size; later items use smaller text and
 * are wrapped in a bordered, tinted box. String content is rendered through
 * TruncatableText, multimodal content as thumbnails, and anything else as
 * pretty-printed JSON.
 */
export const RenderNestedMessages: React.FC<{
  content: AgentMessageConfig[];
}> = ({ content }) => (
  <div className="space-y-4">
    {content.map((item, index) => (
      <div
        key={index}
        className={`${
          index > 0 ? "border border-secondary rounded bg-secondary/30" : ""
        }`}
      >
        {typeof item.content === "string" ? (
          <TruncatableText
            content={item.content}
            className={`break-all ${index === 0 ? "text-base" : "text-sm"}`}
          />
        ) : messageUtils.isMultiModalContent(item.content) ? (
          <RenderMultiModal content={item.content} thumbnail />
        ) : (
          <pre className="text-xs whitespace-pre-wrap overflow-x-auto">
            {JSON.stringify(item.content, null, 2)}
          </pre>
        )}
      </div>
    ))}
  </div>
);
export const RenderMessage: React.FC<MessageProps> = ({
message,
isLast = false,
className = "",
}) => {
if (!message) return null;
// If message is an array, render the first message or return null
if (messageUtils.isMessageArray(message)) {
return message.length > 0 ? (
<RenderMessage
message={message[0]}
isLast={isLast}
className={className}
/>
) : null;
}
const isUser = messageUtils.isUser(message.source);
const content = message.content;
const isLLMEventMessage = message.source === "llm_call_event";
@ -186,7 +247,9 @@ export const RenderMessage: React.FC<MessageProps> = ({
{messageUtils.isToolCallContent(content) ? (
<RenderToolCall content={content} />
) : messageUtils.isMultiModalContent(content) ? (
<RenderMultiModal content={content} />
<RenderMultiModal content={content} thumbnail />
) : messageUtils.isNestedMessageContent(content) ? (
<RenderNestedMessages content={content} />
) : messageUtils.isFunctionExecutionResult(content) ? (
<RenderToolResult content={content} />
) : message.source === "llm_call_event" ? (
@ -198,7 +261,6 @@ export const RenderMessage: React.FC<MessageProps> = ({
/>
)}
</div>
{message.models_usage && (
<div className="text-xs text-secondary mt-1">
Tokens:{" "}

View File

@ -113,6 +113,8 @@ const RunView: React.FC<RunViewProps> = ({
return run.messages.filter((msg) => msg.config.source !== "llm_call_event");
}, [run.messages, uiSettings.show_llm_call_events]);
console.log("Run task", run.task);
// Replace existing scroll effect with this simpler one
useEffect(() => {
setTimeout(() => {

View File

@ -33,7 +33,7 @@ import {
} from "lucide-react";
import { useTeamBuilderStore } from "./store";
import { ComponentLibrary } from "./library";
import { ComponentTypes, Team } from "../../../types/datamodel";
import { ComponentTypes, Gallery, Team } from "../../../types/datamodel";
import { CustomNode, CustomEdge, DragItem } from "./types";
import { edgeTypes, nodeTypes } from "./nodes";
@ -46,7 +46,7 @@ import TestDrawer from "./testdrawer";
import { validationAPI, ValidationResponse } from "../api";
import { ValidationErrors } from "./validationerrors";
import ComponentEditor from "./component-editor/component-editor";
import { useGalleryStore } from "../../gallery/store";
// import { useGalleryStore } from "../../gallery/store";
const { Sider, Content } = Layout;
interface DragItemData {
@ -60,12 +60,14 @@ interface TeamBuilderProps {
team: Team;
onChange?: (team: Partial<Team>) => void;
onDirtyStateChange?: (isDirty: boolean) => void;
selectedGallery?: Gallery | null;
}
export const TeamBuilder: React.FC<TeamBuilderProps> = ({
team,
onChange,
onDirtyStateChange,
selectedGallery,
}) => {
// Replace store state with React Flow hooks
const [nodes, setNodes, onNodesChange] = useNodesState<CustomNode>([]);
@ -86,7 +88,7 @@ export const TeamBuilder: React.FC<TeamBuilderProps> = ({
const [validationLoading, setValidationLoading] = useState(false);
const [testDrawerVisible, setTestDrawerVisible] = useState(false);
const defaultGallery = useGalleryStore((state) => state.getSelectedGallery());
const {
undo,
redo,
@ -465,8 +467,8 @@ export const TeamBuilder: React.FC<TeamBuilderProps> = ({
onDragStart={handleDragStart}
>
<Layout className=" relative bg-primary h-[calc(100vh-239px)] rounded">
{!isJsonMode && defaultGallery && (
<ComponentLibrary defaultGallery={defaultGallery} />
{!isJsonMode && selectedGallery && (
<ComponentLibrary defaultGallery={selectedGallery} />
)}
<Layout className="bg-primary rounded">

View File

@ -5,7 +5,7 @@ import { appContext } from "../../../hooks/provider";
import { teamAPI } from "./api";
import { useGalleryStore } from "../gallery/store";
import { TeamSidebar } from "./sidebar";
import type { Team } from "../../types/datamodel";
import { Gallery, type Team } from "../../types/datamodel";
import { TeamBuilder } from "./builder/builder";
export const TeamManager: React.FC = () => {
@ -19,18 +19,12 @@ export const TeamManager: React.FC = () => {
}
});
const [selectedGallery, setSelectedGallery] = useState<Gallery | null>(null);
const { user } = useContext(appContext);
const [messageApi, contextHolder] = message.useMessage();
const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false);
// Initialize galleries
const fetchGalleries = useGalleryStore((state) => state.fetchGalleries);
useEffect(() => {
if (user?.id) {
fetchGalleries(user.id);
}
}, [user?.id, fetchGalleries]);
// Persist sidebar state
useEffect(() => {
if (typeof window !== "undefined") {
@ -171,6 +165,8 @@ export const TeamManager: React.FC = () => {
onEditTeam={setCurrentTeam}
onDeleteTeam={handleDeleteTeam}
isLoading={isLoading}
setSelectedGallery={setSelectedGallery}
selectedGallery={selectedGallery}
/>
</div>
@ -205,6 +201,7 @@ export const TeamManager: React.FC = () => {
team={currentTeam}
onChange={handleSaveTeam}
onDirtyStateChange={setHasUnsavedChanges}
selectedGallery={selectedGallery}
/>
) : (
<div className="flex items-center justify-center h-[calc(100vh-190px)] text-secondary">

View File

@ -1,4 +1,4 @@
import React, { useState } from "react";
import React, { useContext, useState } from "react";
import { Button, Tooltip, Select, message } from "antd";
import {
Bot,
@ -12,9 +12,12 @@ import {
RefreshCcw,
History,
} from "lucide-react";
import type { Team } from "../../types/datamodel";
import type { Gallery, Team } from "../../types/datamodel";
import { getRelativeTimeString } from "../atoms";
import { useGalleryStore } from "../gallery/store";
import { GalleryAPI } from "../gallery/api";
import { appContext } from "../../../hooks/provider";
import { Link } from "gatsby";
import { getLocalStorage, setLocalStorage } from "../../utils/utils";
interface TeamSidebarProps {
isOpen: boolean;
@ -26,6 +29,8 @@ interface TeamSidebarProps {
onEditTeam: (team: Team) => void;
onDeleteTeam: (teamId: number) => void;
isLoading?: boolean;
selectedGallery: Gallery | null;
setSelectedGallery: (gallery: Gallery) => void;
}
export const TeamSidebar: React.FC<TeamSidebarProps> = ({
@ -38,18 +43,50 @@ export const TeamSidebar: React.FC<TeamSidebarProps> = ({
onEditTeam,
onDeleteTeam,
isLoading = false,
selectedGallery,
setSelectedGallery,
}) => {
// Tab state - "recent" or "gallery"
const [activeTab, setActiveTab] = useState<"recent" | "gallery">("recent");
const [messageApi, contextHolder] = message.useMessage();
// Gallery store
const {
galleries,
selectedGallery,
selectGallery,
isLoading: isLoadingGalleries,
} = useGalleryStore();
const [isLoadingGalleries, setIsLoadingGalleries] = useState(false);
const [galleries, setGalleries] = useState<Gallery[]>([]);
const { user } = useContext(appContext);
// Fetch galleries
/**
 * Loads the current user's galleries and chooses the active one.
 *
 * The gallery whose id was saved in localStorage is selected when it is still
 * in the fetched list. Otherwise, if nothing is selected yet, the first
 * gallery is used. Errors are logged to the console, and the loading flag is
 * always cleared.
 */
const fetchGalleries = async () => {
  if (!user?.id) return;

  setIsLoadingGalleries(true);
  try {
    const galleryList = await new GalleryAPI().listGalleries(user.id);
    setGalleries(galleryList);

    // Look up the gallery remembered from a previous visit, if any
    const storedId = getLocalStorage(`selectedGalleryId_${user.id}`);
    const storedGallery = storedId
      ? galleryList.find((g) => g.id === storedId)
      : undefined;

    if (storedGallery) {
      setSelectedGallery(storedGallery);
    } else if (!selectedGallery && galleryList.length > 0) {
      setSelectedGallery(galleryList[0]);
    }
  } catch (error) {
    console.error("Error fetching galleries:", error);
  } finally {
    setIsLoadingGalleries(false);
  }
};
// Fetch galleries on mount
React.useEffect(() => {
fetchGalleries();
}, [user?.id]);
// Render collapsed state
if (!isOpen) {
@ -262,13 +299,28 @@ export const TeamSidebar: React.FC<TeamSidebarProps> = ({
{activeTab === "gallery" && (
<div className="p-2">
{/* Gallery Selector */}
<div className="my-2 mb-3 text-xs">
{" "}
Select a{" "}
<Link to="/gallery" className="text-accent">
<span className="font-medium">gallery</span>
</Link>{" "}
to view its components as templates
</div>
<Select
className="w-full mb-4"
placeholder="Select gallery"
value={selectedGallery?.id}
onChange={(value) => {
const gallery = galleries.find((g) => g.id === value);
if (gallery) selectGallery(gallery);
if (gallery) {
setSelectedGallery(gallery);
// Save the selected gallery ID to localStorage
if (user?.id) {
setLocalStorage(`selectedGalleryId_${user.id}`, value);
}
}
}}
options={galleries.map((gallery) => ({
value: gallery.id,

View File

@ -0,0 +1,29 @@
import * as React from "react";
import Layout from "../components/layout";
import { graphql } from "gatsby";
import DeployManager from "../components/views/deploy/manager";
import LabsManager from "../components/views/labs/manager";
// Page component for the /labs route: wraps LabsManager in the shared Layout,
// passing the site metadata returned by the page query.
const LabsPage = ({ data }: any) => (
  <Layout meta={data.site.siteMetadata} title="Home" link={"/labs"}>
    <main style={{ height: "100%" }} className=" h-full ">
      <LabsManager />
    </main>
  </Layout>
);
export const query = graphql`
query HomePageQuery {
site {
siteMetadata {
description
title
}
}
}
`;
export default LabsPage;

View File

@ -381,3 +381,12 @@ div#gatsby-focus-wrapper {
height: 100%;
/* border: 1px solid green; */
}
/* Removes padding, margin and border from the upload wrapper and the button
   inside any element marked .zero-padding-upload. !important is used to
   override the component library's own rules. */
.zero-padding-upload.ant-upload,
.zero-padding-upload .ant-upload,
.zero-padding-upload .ant-upload-select,
.zero-padding-upload .ant-btn {
padding: 0 !important;
margin: 0 !important;
border: none !important;
}

View File

@ -54,7 +54,7 @@ class TestDatabaseOperations:
def test_basic_setup(self, test_db: DatabaseManager):
"""Test basic database setup and connection"""
with Session(test_db.engine) as session:
result = session.exec(text("SELECT 1")).first()
result = session.exec(text("SELECT 1")).first() # type: ignore
assert result[0] == 1
result = session.exec(select(1)).first()
assert result == 1
@ -85,7 +85,7 @@ class TestDatabaseOperations:
# Verify Update
result = test_db.get(Team, {"id": team_id})
assert result.status is True
assert result.data[0].version == "0.0.2"
assert result.data and result.data[0].version == "0.0.2"
def test_delete_operations(self, test_db: DatabaseManager, sample_team: Team):
"""Test delete with various filters"""
@ -103,7 +103,8 @@ class TestDatabaseOperations:
# Verify deletion
result = test_db.get(Team, {"id": team_id})
assert len(result.data) == 0
if result.data:
assert len(result.data) == 0
def test_cascade_delete(self, test_db: DatabaseManager, test_user: str):
"""Test all levels of cascade delete"""
@ -133,7 +134,9 @@ class TestDatabaseOperations:
))
test_db.delete(Run, {"id": run1_id})
assert len(test_db.get(Message, {"run_id": run1_id}).data) == 0, "Run->Message cascade failed"
db_message = test_db.get(Message, {"run_id": run1_id})
if db_message.data:
assert len(db_message.data) == 0, "Run->Message cascade failed"
# Test Session -> Run -> Message cascade
session2 = SessionModel(user_id=test_user, team_id=team1.id, name="Session2")
@ -154,8 +157,12 @@ class TestDatabaseOperations:
))
test_db.delete(SessionModel, {"id": session2.id})
assert len(test_db.get(Run, {"session_id": session2.id}).data) == 0, "Session->Run cascade failed"
assert len(test_db.get(Message, {"run_id": run2_id}).data) == 0, "Session->Run->Message cascade failed"
session = test_db.get(SessionModel, {"id": session2.id})
run = test_db.get(Run, {"id": run2_id})
if session.data:
assert len(session.data) == 0, "Session->Run cascade failed"
if run.data:
assert len(run.data) == 0, "Session->Run->Message cascade failed"
# Clean up
test_db.delete(Team, {"id": team1.id})

View File

@ -808,7 +808,7 @@ requires-dist = [
[[package]]
name = "autogenstudio"
version = "0.4.1"
version = "0.4.2"
source = { editable = "packages/autogen-studio" }
dependencies = [
{ name = "aiofiles" },