mirror of https://github.com/microsoft/autogen.git
[Draft] Enable File Upload/Paste as Task in AGS (#6091)
<!-- Thank you for your contribution! Please review https://microsoft.github.io/autogen/docs/Contribute before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? https://github.com/user-attachments/assets/e160f16d-f42d-49e2-a6c6-687e4e6786f4 Enable file upload/paste as a task in AGS. Enables tasks like - Can you research and fact check the ideas in this screenshot? - Summarize this file Only text and images supported for now Underneath, it constructs TextMessage and Multimodal messages as the task. <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number <!-- For example: "Closes #1234" --> Closes #5773 ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
This commit is contained in:
parent
cc806a57ef
commit
32d2a18bf1
|
@ -8,7 +8,7 @@ from loguru import logger
|
||||||
from sqlalchemy import exc, inspect, text
|
from sqlalchemy import exc, inspect, text
|
||||||
from sqlmodel import Session, SQLModel, and_, create_engine, select
|
from sqlmodel import Session, SQLModel, and_, create_engine, select
|
||||||
|
|
||||||
from ..datamodel import Response, Team
|
from ..datamodel import BaseDBModel, Response, Team
|
||||||
from ..teammanager import TeamManager
|
from ..teammanager import TeamManager
|
||||||
from .schema_manager import SchemaManager
|
from .schema_manager import SchemaManager
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ class DatabaseManager:
|
||||||
finally:
|
finally:
|
||||||
self._init_lock.release()
|
self._init_lock.release()
|
||||||
|
|
||||||
def reset_db(self, recreate_tables: bool = True):
|
def reset_db(self, recreate_tables: bool = True) -> Response:
|
||||||
"""
|
"""
|
||||||
Reset the database by dropping all tables and optionally recreating them.
|
Reset the database by dropping all tables and optionally recreating them.
|
||||||
|
|
||||||
|
@ -151,7 +151,7 @@ class DatabaseManager:
|
||||||
self._init_lock.release()
|
self._init_lock.release()
|
||||||
logger.info("Database reset lock released")
|
logger.info("Database reset lock released")
|
||||||
|
|
||||||
def upsert(self, model: SQLModel, return_json: bool = True) -> Response:
|
def upsert(self, model: BaseDBModel, return_json: bool = True) -> Response:
|
||||||
"""Create or update an entity
|
"""Create or update an entity
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -199,7 +199,7 @@ class DatabaseManager:
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
model_class: SQLModel,
|
model_class: type[BaseDBModel],
|
||||||
filters: dict | None = None,
|
filters: dict | None = None,
|
||||||
return_json: bool = False,
|
return_json: bool = False,
|
||||||
order: str = "desc",
|
order: str = "desc",
|
||||||
|
@ -211,7 +211,7 @@ class DatabaseManager:
|
||||||
status_message = ""
|
status_message = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
statement = select(model_class)
|
statement = select(model_class) # type: ignore
|
||||||
if filters:
|
if filters:
|
||||||
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
||||||
statement = statement.where(and_(*conditions))
|
statement = statement.where(and_(*conditions))
|
||||||
|
@ -231,7 +231,7 @@ class DatabaseManager:
|
||||||
|
|
||||||
return Response(message=status_message, status=status, data=result)
|
return Response(message=status_message, status=status, data=result)
|
||||||
|
|
||||||
def delete(self, model_class: SQLModel, filters: dict = None) -> Response:
|
def delete(self, model_class: type[BaseDBModel], filters: dict | None = None) -> Response:
|
||||||
"""Delete an entity"""
|
"""Delete an entity"""
|
||||||
status_message = ""
|
status_message = ""
|
||||||
status = True
|
status = True
|
||||||
|
@ -239,8 +239,8 @@ class DatabaseManager:
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
try:
|
try:
|
||||||
if "sqlite" in str(self.engine.url):
|
if "sqlite" in str(self.engine.url):
|
||||||
session.exec(text("PRAGMA foreign_keys=ON"))
|
session.exec(text("PRAGMA foreign_keys=ON")) # type: ignore
|
||||||
statement = select(model_class)
|
statement = select(model_class) # type: ignore
|
||||||
if filters:
|
if filters:
|
||||||
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
||||||
statement = statement.where(and_(*conditions))
|
statement = statement.where(and_(*conditions))
|
||||||
|
@ -326,7 +326,7 @@ class DatabaseManager:
|
||||||
{
|
{
|
||||||
"status": result.status,
|
"status": result.status,
|
||||||
"message": result.message,
|
"message": result.message,
|
||||||
"id": result.data.get("id") if result.data else None,
|
"id": result.data.get("id") if result.data and result.data is not None else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -342,7 +342,8 @@ class DatabaseManager:
|
||||||
|
|
||||||
async def _check_team_exists(self, config: dict, user_id: str) -> Optional[Team]:
|
async def _check_team_exists(self, config: dict, user_id: str) -> Optional[Team]:
|
||||||
"""Check if identical team config already exists"""
|
"""Check if identical team config already exists"""
|
||||||
teams = self.get(Team, {"user_id": user_id}).data
|
response = self.get(Team, {"user_id": user_id})
|
||||||
|
teams = response.data if response.status and response.data is not None else []
|
||||||
|
|
||||||
for team in teams:
|
for team in teams:
|
||||||
if team.component == config:
|
if team.component == config:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .db import Gallery, Message, Run, RunStatus, Session, Settings, Team
|
from .db import BaseDBModel, Gallery, Message, Run, RunStatus, Session, Settings, Team
|
||||||
from .types import (
|
from .types import (
|
||||||
EnvironmentVariable,
|
EnvironmentVariable,
|
||||||
GalleryComponents,
|
GalleryComponents,
|
||||||
|
|
|
@ -2,13 +2,14 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from autogen_core import ComponentModel
|
from autogen_core import ComponentModel
|
||||||
from pydantic import ConfigDict, SecretStr
|
from pydantic import ConfigDict, SecretStr, field_validator
|
||||||
from sqlalchemy import ForeignKey, Integer, String
|
from sqlalchemy import ForeignKey, Integer
|
||||||
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func
|
from sqlmodel import JSON, Column, DateTime, Field, SQLModel, func
|
||||||
|
|
||||||
|
from .eval import EvalJudgeCriteria, EvalRunResult, EvalRunStatus, EvalScore, EvalTask
|
||||||
from .types import (
|
from .types import (
|
||||||
GalleryComponents,
|
GalleryComponents,
|
||||||
GalleryConfig,
|
GalleryConfig,
|
||||||
|
@ -20,35 +21,41 @@ from .types import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Team(SQLModel, table=True):
|
class BaseDBModel(SQLModel, table=False):
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
"""
|
||||||
|
Base model with common fields for all database tables.
|
||||||
|
Not a table itself - meant to be inherited by concrete model classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
# Common fields present in all database tables
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
id: Optional[int] = Field(default=None, primary_key=True)
|
||||||
|
|
||||||
created_at: datetime = Field(
|
created_at: datetime = Field(
|
||||||
default_factory=datetime.now,
|
default_factory=datetime.now,
|
||||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
sa_type=DateTime(timezone=True), # type: ignore[assignment]
|
||||||
) # pylint: disable=not-callable
|
sa_column_kwargs={"server_default": func.now(), "nullable": True},
|
||||||
|
)
|
||||||
|
|
||||||
updated_at: datetime = Field(
|
updated_at: datetime = Field(
|
||||||
default_factory=datetime.now,
|
default_factory=datetime.now,
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
sa_type=DateTime(timezone=True), # type: ignore[assignment]
|
||||||
) # pylint: disable=not-callable
|
sa_column_kwargs={"onupdate": func.now(), "nullable": True},
|
||||||
|
)
|
||||||
|
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
version: Optional[str] = "0.0.1"
|
version: Optional[str] = "0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
class Team(BaseDBModel, table=True):
|
||||||
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
|
component: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
|
||||||
|
|
||||||
|
|
||||||
class Message(SQLModel, table=True):
|
class Message(BaseDBModel, table=True):
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
created_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
version: Optional[str] = "0.0.1"
|
|
||||||
config: Union[MessageConfig, dict] = Field(
|
config: Union[MessageConfig, dict] = Field(
|
||||||
default_factory=lambda: MessageConfig(source="", content=""), sa_column=Column(JSON)
|
default_factory=lambda: MessageConfig(source="", content=""), sa_column=Column(JSON)
|
||||||
)
|
)
|
||||||
|
@ -60,22 +67,18 @@ class Message(SQLModel, table=True):
|
||||||
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
|
message_meta: Optional[Union[MessageMeta, dict]] = Field(default={}, sa_column=Column(JSON))
|
||||||
|
|
||||||
|
|
||||||
class Session(SQLModel, table=True):
|
class Session(BaseDBModel, table=True):
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
created_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
version: Optional[str] = "0.0.1"
|
|
||||||
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
|
team_id: Optional[int] = Field(default=None, sa_column=Column(Integer, ForeignKey("team.id", ondelete="CASCADE")))
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator("created_at", "updated_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_datetime(cls, value: Union[str, datetime]) -> datetime:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class RunStatus(str, Enum):
|
class RunStatus(str, Enum):
|
||||||
CREATED = "created"
|
CREATED = "created"
|
||||||
|
@ -85,18 +88,11 @@ class RunStatus(str, Enum):
|
||||||
STOPPED = "stopped"
|
STOPPED = "stopped"
|
||||||
|
|
||||||
|
|
||||||
class Run(SQLModel, table=True):
|
class Run(BaseDBModel, table=True):
|
||||||
"""Represents a single execution run within a session"""
|
"""Represents a single execution run within a session"""
|
||||||
|
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
|
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
created_at: datetime = Field(
|
|
||||||
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
)
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=datetime.now, sa_column=Column(DateTime(timezone=True), onupdate=func.now())
|
|
||||||
)
|
|
||||||
session_id: Optional[int] = Field(
|
session_id: Optional[int] = Field(
|
||||||
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
|
default=None, sa_column=Column(Integer, ForeignKey("session.id", ondelete="CASCADE"), nullable=False)
|
||||||
)
|
)
|
||||||
|
@ -118,19 +114,9 @@ class Run(SQLModel, table=True):
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class Gallery(SQLModel, table=True):
|
class Gallery(BaseDBModel, table=True):
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
created_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
version: Optional[str] = "0.0.1"
|
|
||||||
config: Union[GalleryConfig, dict] = Field(
|
config: Union[GalleryConfig, dict] = Field(
|
||||||
default_factory=lambda: GalleryConfig(
|
default_factory=lambda: GalleryConfig(
|
||||||
id="",
|
id="",
|
||||||
|
@ -149,17 +135,64 @@ class Gallery(SQLModel, table=True):
|
||||||
) # type: ignore[call-arg]
|
) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
|
||||||
class Settings(SQLModel, table=True):
|
class Settings(BaseDBModel, table=True):
|
||||||
__table_args__ = {"sqlite_autoincrement": True}
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
id: Optional[int] = Field(default=None, primary_key=True)
|
|
||||||
created_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=datetime.now,
|
|
||||||
sa_column=Column(DateTime(timezone=True), onupdate=func.now()),
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
version: Optional[str] = "0.0.1"
|
|
||||||
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))
|
config: Union[SettingsConfig, dict] = Field(default_factory=SettingsConfig, sa_column=Column(JSON))
|
||||||
|
|
||||||
|
|
||||||
|
# --- Evaluation system database models ---
|
||||||
|
|
||||||
|
|
||||||
|
class EvalTaskDB(BaseDBModel, table=True):
|
||||||
|
"""Database model for storing evaluation tasks."""
|
||||||
|
|
||||||
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
|
|
||||||
|
name: str = "Unnamed Task"
|
||||||
|
description: str = ""
|
||||||
|
config: Union[EvalTask, dict] = Field(sa_column=Column(JSON))
|
||||||
|
|
||||||
|
|
||||||
|
class EvalCriteriaDB(BaseDBModel, table=True):
|
||||||
|
"""Database model for storing evaluation criteria."""
|
||||||
|
|
||||||
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
|
|
||||||
|
name: str = "Unnamed Criteria"
|
||||||
|
description: str = ""
|
||||||
|
config: Union[EvalJudgeCriteria, dict] = Field(sa_column=Column(JSON))
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRunDB(BaseDBModel, table=True):
|
||||||
|
"""Database model for tracking evaluation runs."""
|
||||||
|
|
||||||
|
__table_args__ = {"sqlite_autoincrement": True}
|
||||||
|
|
||||||
|
name: str = "Unnamed Evaluation Run"
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
# References to related components
|
||||||
|
task_id: Optional[int] = Field(
|
||||||
|
default=None, sa_column=Column(Integer, ForeignKey("evaltaskdb.id", ondelete="SET NULL"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Serialized configurations for runner and judge
|
||||||
|
runner_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
|
||||||
|
judge_config: Union[ComponentModel, dict] = Field(sa_column=Column(JSON))
|
||||||
|
|
||||||
|
# List of criteria IDs or embedded criteria configs
|
||||||
|
criteria_configs: List[Union[EvalJudgeCriteria, dict]] = Field(default_factory=list, sa_column=Column(JSON))
|
||||||
|
|
||||||
|
# Run status and timing information
|
||||||
|
status: EvalRunStatus = Field(default=EvalRunStatus.PENDING)
|
||||||
|
start_time: Optional[datetime] = Field(default=None)
|
||||||
|
end_time: Optional[datetime] = Field(default=None)
|
||||||
|
|
||||||
|
# Results (updated as they become available)
|
||||||
|
run_result: Union[EvalRunResult, dict] = Field(default=None, sa_column=Column(JSON))
|
||||||
|
|
||||||
|
score_result: Union[EvalScore, dict] = Field(default=None, sa_column=Column(JSON))
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
# datamodel/eval.py
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from autogen_agentchat.base import TaskResult
|
||||||
|
from autogen_core import Image
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlmodel import Field
|
||||||
|
|
||||||
|
|
||||||
|
class EvalTask(BaseModel):
|
||||||
|
"""Definition of a task to be evaluated."""
|
||||||
|
|
||||||
|
task_id: UUID | str = Field(default_factory=uuid4)
|
||||||
|
input: str | Sequence[str | Image]
|
||||||
|
name: str = ""
|
||||||
|
description: str = ""
|
||||||
|
expected_outputs: Optional[List[Any]] = None
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRunResult(BaseModel):
|
||||||
|
"""Result of an evaluation run."""
|
||||||
|
|
||||||
|
result: TaskResult | None = None
|
||||||
|
status: bool = False
|
||||||
|
start_time: Optional[datetime] = Field(default=datetime.now())
|
||||||
|
end_time: Optional[datetime] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class EvalDimensionScore(BaseModel):
|
||||||
|
"""Score for a single evaluation dimension."""
|
||||||
|
|
||||||
|
dimension: str
|
||||||
|
score: float
|
||||||
|
reason: str
|
||||||
|
max_value: float
|
||||||
|
min_value: float
|
||||||
|
|
||||||
|
|
||||||
|
class EvalScore(BaseModel):
|
||||||
|
"""Composite score from evaluation."""
|
||||||
|
|
||||||
|
overall_score: Optional[float] = None
|
||||||
|
dimension_scores: List[EvalDimensionScore] = []
|
||||||
|
reason: Optional[str] = None
|
||||||
|
max_value: float = 10.0
|
||||||
|
min_value: float = 0.0
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class EvalJudgeCriteria(BaseModel):
|
||||||
|
"""Criteria for judging evaluation results."""
|
||||||
|
|
||||||
|
dimension: str
|
||||||
|
prompt: str
|
||||||
|
max_value: float = 10.0
|
||||||
|
min_value: float = 0.0
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRunStatus(str, Enum):
|
||||||
|
"""Status of an evaluation run."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELED = "canceled"
|
||||||
|
|
||||||
|
|
||||||
|
class EvalResult(BaseModel):
|
||||||
|
"""Result of an evaluation run."""
|
||||||
|
|
||||||
|
task_id: UUID | str
|
||||||
|
# runner_id: UUID | str
|
||||||
|
status: EvalRunStatus = EvalRunStatus.PENDING
|
||||||
|
start_time: Optional[datetime] = Field(default=datetime.now())
|
||||||
|
end_time: Optional[datetime] = None
|
|
@ -1,9 +1,9 @@
|
||||||
# from dataclasses import Field
|
# from dataclasses import Field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
from autogen_agentchat.base import TaskResult
|
from autogen_agentchat.base import TaskResult
|
||||||
from autogen_agentchat.messages import BaseChatMessage
|
from autogen_agentchat.messages import ChatMessage, TextMessage
|
||||||
from autogen_core import ComponentModel
|
from autogen_core import ComponentModel
|
||||||
from autogen_core.models import UserMessage
|
from autogen_core.models import UserMessage
|
||||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||||
|
@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
class MessageConfig(BaseModel):
|
class MessageConfig(BaseModel):
|
||||||
source: str
|
source: str
|
||||||
content: str
|
content: str | ChatMessage | Sequence[ChatMessage] | None
|
||||||
message_type: Optional[str] = "text"
|
message_type: Optional[str] = "text"
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,9 +22,8 @@ class TeamResult(BaseModel):
|
||||||
duration: float
|
duration: float
|
||||||
|
|
||||||
|
|
||||||
class LLMCallEventMessage(BaseChatMessage):
|
class LLMCallEventMessage(TextMessage):
|
||||||
source: str = "llm_call_event"
|
source: str = "llm_call_event"
|
||||||
content: str
|
|
||||||
|
|
||||||
def to_text(self) -> str:
|
def to_text(self) -> str:
|
||||||
return self.content
|
return self.content
|
||||||
|
|
|
@ -0,0 +1,267 @@
|
||||||
|
import asyncio
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from autogen_core import CancellationToken, Component, ComponentBase
|
||||||
|
from autogen_core.models import ChatCompletionClient, UserMessage
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ..datamodel.eval import EvalDimensionScore, EvalJudgeCriteria, EvalRunResult, EvalScore, EvalTask
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvalJudgeConfig(BaseModel):
|
||||||
|
"""Base configuration for evaluation judges."""
|
||||||
|
|
||||||
|
name: str = "Base Judge"
|
||||||
|
description: str = ""
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvalJudge(ABC, ComponentBase[BaseEvalJudgeConfig]):
|
||||||
|
"""Abstract base class for evaluation judges."""
|
||||||
|
|
||||||
|
component_type = "eval_judge"
|
||||||
|
|
||||||
|
def __init__(self, name: str = "Base Judge", description: str = "", metadata: Optional[Dict[str, Any]] = None):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.metadata = metadata or {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def judge(
|
||||||
|
self,
|
||||||
|
task: EvalTask,
|
||||||
|
result: EvalRunResult,
|
||||||
|
criteria: List[EvalJudgeCriteria],
|
||||||
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
) -> EvalScore:
|
||||||
|
"""Judge the result of an evaluation run."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _to_config(self) -> BaseEvalJudgeConfig:
|
||||||
|
"""Convert the judge configuration to a configuration object for serialization."""
|
||||||
|
return BaseEvalJudgeConfig(name=self.name, description=self.description, metadata=self.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEvalJudgeConfig(BaseEvalJudgeConfig):
|
||||||
|
"""Configuration for LLMEvalJudge."""
|
||||||
|
|
||||||
|
model_client: Any # ComponentModel
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEvalJudge(BaseEvalJudge, Component[LLMEvalJudgeConfig]):
|
||||||
|
"""Judge that uses an LLM to evaluate results."""
|
||||||
|
|
||||||
|
component_config_schema = LLMEvalJudgeConfig
|
||||||
|
component_type = "eval_judge"
|
||||||
|
component_provider_override = "autogenstudio.eval.judges.LLMEvalJudge"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_client: ChatCompletionClient,
|
||||||
|
name: str = "LLM Judge",
|
||||||
|
description: str = "Evaluates results using an LLM",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, description, metadata)
|
||||||
|
self.model_client = model_client
|
||||||
|
|
||||||
|
async def judge(
|
||||||
|
self,
|
||||||
|
task: EvalTask,
|
||||||
|
result: EvalRunResult,
|
||||||
|
criteria: List[EvalJudgeCriteria],
|
||||||
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
) -> EvalScore:
|
||||||
|
"""Judge the result using an LLM."""
|
||||||
|
# Create a score object
|
||||||
|
score = EvalScore(max_value=10.0)
|
||||||
|
|
||||||
|
# Judge each dimension in parallel
|
||||||
|
dimension_score_tasks = []
|
||||||
|
for criterion in criteria:
|
||||||
|
dimension_score_tasks.append(self._judge_dimension(task, result, criterion, cancellation_token))
|
||||||
|
|
||||||
|
dimension_scores = await asyncio.gather(*dimension_score_tasks)
|
||||||
|
score.dimension_scores = dimension_scores
|
||||||
|
|
||||||
|
# Calculate overall score (average of dimension scores)
|
||||||
|
valid_scores = [ds.score for ds in dimension_scores if ds.score is not None]
|
||||||
|
if valid_scores:
|
||||||
|
score.overall_score = sum(valid_scores) / len(valid_scores)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
async def _judge_dimension(
|
||||||
|
self,
|
||||||
|
task: EvalTask,
|
||||||
|
result: EvalRunResult,
|
||||||
|
criterion: EvalJudgeCriteria,
|
||||||
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
) -> EvalDimensionScore:
|
||||||
|
"""Judge a specific dimension."""
|
||||||
|
# Format task and result for the LLM
|
||||||
|
task_description = self._format_task(task)
|
||||||
|
result_description = result.model_dump()
|
||||||
|
|
||||||
|
# Create the prompt
|
||||||
|
prompt = f"""
|
||||||
|
You are evaluating the quality of a system response to a task.
|
||||||
|
Task: {task_description}Response: {result_description}
|
||||||
|
Evaluation criteria: {criterion.dimension}
|
||||||
|
{criterion.prompt}
|
||||||
|
Score the response on a scale from {criterion.min_value} to {criterion.max_value}.
|
||||||
|
First, provide a detailed explanation of your evaluation.
|
||||||
|
Then, give your final score as a single number between 0 and {criterion.max_value}.
|
||||||
|
Format your answer should be a json for the EvalDimensionScore class:
|
||||||
|
{{
|
||||||
|
"dimension": "{criterion.dimension}",
|
||||||
|
"reason": "<explanation>",
|
||||||
|
"score": <score>
|
||||||
|
}}
|
||||||
|
Please ensure the score is a number between {criterion.min_value} and {criterion.max_value}.
|
||||||
|
If you cannot evaluate the response, please return a score of null.
|
||||||
|
If the response is not relevant, please return a score of 0.
|
||||||
|
If the response is perfect, please return a score of {criterion.max_value}.
|
||||||
|
If the response is not relevant, please return a score of 0.
|
||||||
|
If the response is perfect, please return a score of {criterion.max_value}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get judgment from LLM
|
||||||
|
model_input = []
|
||||||
|
text_message = UserMessage(content=prompt, source="user")
|
||||||
|
model_input.append(text_message)
|
||||||
|
|
||||||
|
# Run with the model client in the same format as used in runners
|
||||||
|
model_result = await self.model_client.create(
|
||||||
|
messages=model_input,
|
||||||
|
cancellation_token=cancellation_token,
|
||||||
|
json_output=EvalDimensionScore,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract content from the response
|
||||||
|
model_response = model_result.content if isinstance(model_result.content, str) else str(model_result.content)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# validate response string as EvalDimensionScore
|
||||||
|
model_response = EvalDimensionScore.model_validate_json(model_response)
|
||||||
|
return model_response
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse LLM response: {e}", model_result.content)
|
||||||
|
return EvalDimensionScore(
|
||||||
|
dimension=criterion.dimension,
|
||||||
|
reason="Failed to parse response",
|
||||||
|
score=0.0,
|
||||||
|
max_value=criterion.max_value,
|
||||||
|
min_value=criterion.min_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_task(self, task: EvalTask) -> str:
|
||||||
|
"""Format the task for the LLM."""
|
||||||
|
task_parts = []
|
||||||
|
|
||||||
|
if task.description:
|
||||||
|
task_parts.append(task.description)
|
||||||
|
if isinstance(task.input, str):
|
||||||
|
task_parts.append(task.input)
|
||||||
|
elif isinstance(task.input, list):
|
||||||
|
task_parts.append("\n".join(str(x) for x in task.input if isinstance(x, str)))
|
||||||
|
|
||||||
|
return "\n".join(task_parts)
|
||||||
|
|
||||||
|
def _parse_judgment(self, judgment_text: str, max_value: float) -> Tuple[str, Optional[float]]:
|
||||||
|
"""Parse judgment text to extract explanation and score."""
|
||||||
|
explanation = ""
|
||||||
|
score = None
|
||||||
|
|
||||||
|
# Simple parsing - could be improved with regex
|
||||||
|
lines = judgment_text.split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if line.strip().lower().startswith("explanation:"):
|
||||||
|
explanation = line.split(":", 1)[1].strip()
|
||||||
|
elif line.strip().lower().startswith("score:"):
|
||||||
|
try:
|
||||||
|
score_str = line.split(":", 1)[1].strip()
|
||||||
|
score = float(score_str)
|
||||||
|
# Ensure score is within bounds
|
||||||
|
score = min(max(score, 0), max_value)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return explanation, score
|
||||||
|
|
||||||
|
def _to_config(self) -> LLMEvalJudgeConfig:
|
||||||
|
"""Convert to configuration object including model client configuration."""
|
||||||
|
base_config = super()._to_config()
|
||||||
|
return LLMEvalJudgeConfig(
|
||||||
|
name=base_config.name,
|
||||||
|
description=base_config.description,
|
||||||
|
metadata=base_config.metadata,
|
||||||
|
model_client=self.model_client.dump_component(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config: LLMEvalJudgeConfig) -> Self:
|
||||||
|
"""Create from configuration object with serialized model client."""
|
||||||
|
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||||
|
return cls(
|
||||||
|
model_client=model_client, name=config.name, description=config.description, metadata=config.metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# # Usage example
|
||||||
|
# async def example_usage():
|
||||||
|
# # Create a model client
|
||||||
|
# from autogen_ext.models import OpenAIChatCompletionClient
|
||||||
|
|
||||||
|
# model_client = OpenAIChatCompletionClient(
|
||||||
|
# model="gpt-4",
|
||||||
|
# api_key="your-api-key"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Create a judge
|
||||||
|
# llm_judge = LLMEvalJudge(model_client=model_client)
|
||||||
|
|
||||||
|
# # Serialize the judge to a ComponentModel
|
||||||
|
# judge_config = llm_judge.dump_component()
|
||||||
|
# print(f"Serialized judge: {judge_config}")
|
||||||
|
|
||||||
|
# # Deserialize back to a LLMEvalJudge
|
||||||
|
# deserialized_judge = LLMEvalJudge.load_component(judge_config)
|
||||||
|
|
||||||
|
# # Create criteria for evaluation
|
||||||
|
# criteria = [
|
||||||
|
# EvalJudgeCriteria(
|
||||||
|
# dimension="relevance",
|
||||||
|
# prompt="Evaluate how relevant the response is to the query.",
|
||||||
|
# min_value=0,
|
||||||
|
# max_value=10
|
||||||
|
# ),
|
||||||
|
# EvalJudgeCriteria(
|
||||||
|
# dimension="accuracy",
|
||||||
|
# prompt="Evaluate the factual accuracy of the response.",
|
||||||
|
# min_value=0,
|
||||||
|
# max_value=10
|
||||||
|
# )
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# # Create a mock task and result
|
||||||
|
# task = EvalTask(
|
||||||
|
# id="task-123",
|
||||||
|
# name="Sample Task",
|
||||||
|
# description="A sample task for evaluation",
|
||||||
|
# input="What is the capital of France?"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# result = EvalRunResult(
|
||||||
|
# status=True,
|
||||||
|
# result={
|
||||||
|
# "messages": [{"content": "The capital of France is Paris.", "source": "model"}]
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Run the evaluation
|
||||||
|
# score = await deserialized_judge.judge(task, result, criteria)
|
||||||
|
# print(f"Evaluation score: {score}")
|
|
@ -0,0 +1,789 @@
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pdb import run
|
||||||
|
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..database.db_manager import DatabaseManager
|
||||||
|
from ..datamodel.db import EvalCriteriaDB, EvalRunDB, EvalTaskDB
|
||||||
|
from ..datamodel.eval import EvalJudgeCriteria, EvalRunResult, EvalRunStatus, EvalScore, EvalTask
|
||||||
|
from .judges import BaseEvalJudge
|
||||||
|
from .runners import BaseEvalRunner
|
||||||
|
|
||||||
|
|
||||||
|
class DimensionScore(TypedDict):
|
||||||
|
score: Optional[float]
|
||||||
|
reason: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class RunEntry(TypedDict):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
task_name: str
|
||||||
|
runner_type: str
|
||||||
|
overall_score: Optional[float]
|
||||||
|
scores: List[Optional[float]]
|
||||||
|
reasons: Optional[List[Optional[str]]]
|
||||||
|
|
||||||
|
|
||||||
|
class TabulatedResults(TypedDict):
|
||||||
|
dimensions: List[str]
|
||||||
|
runs: List[RunEntry]
|
||||||
|
|
||||||
|
|
||||||
|
class EvalOrchestrator:
|
||||||
|
"""
|
||||||
|
Orchestrator for evaluation runs.
|
||||||
|
|
||||||
|
This class manages the lifecycle of evaluation tasks, criteria, and runs.
|
||||||
|
It can operate with or without a database manager for persistence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_manager: Optional[DatabaseManager] = None):
|
||||||
|
"""
|
||||||
|
Initialize the orchestrator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_manager: Optional database manager for persistence.
|
||||||
|
If None, data is stored in memory only.
|
||||||
|
"""
|
||||||
|
self._db_manager = db_manager
|
||||||
|
|
||||||
|
# In-memory storage (used when db_manager is None)
|
||||||
|
self._tasks: Dict[str, EvalTask] = {}
|
||||||
|
self._criteria: Dict[str, EvalJudgeCriteria] = {}
|
||||||
|
self._runs: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# Active runs tracking
|
||||||
|
self._active_runs: Dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
# ----- Task Management -----
|
||||||
|
|
||||||
|
async def create_task(self, task: EvalTask) -> str:
|
||||||
|
"""
|
||||||
|
Create a new evaluation task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The evaluation task to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task ID
|
||||||
|
"""
|
||||||
|
if not task.task_id:
|
||||||
|
task.task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if self._db_manager:
|
||||||
|
# Store in database
|
||||||
|
task_db = EvalTaskDB(name=task.name, description=task.description, config=task)
|
||||||
|
response = self._db_manager.upsert(task_db)
|
||||||
|
if not response.status:
|
||||||
|
logger.error(f"Failed to store task: {response.message}")
|
||||||
|
raise RuntimeError(f"Failed to store task: {response.message}")
|
||||||
|
task_id = str(response.data.get("id")) if response.data else str(task.task_id)
|
||||||
|
else:
|
||||||
|
# Store in memory
|
||||||
|
task_id = str(task.task_id)
|
||||||
|
self._tasks[task_id] = task
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
async def get_task(self, task_id: str) -> Optional[EvalTask]:
|
||||||
|
"""
|
||||||
|
Retrieve an evaluation task by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The task if found, None otherwise
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(EvalTaskDB, filters={"id": int(task_id) if task_id.isdigit() else task_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
task_data = response.data[0]
|
||||||
|
return (
|
||||||
|
task_data.get("config")
|
||||||
|
if isinstance(task_data.get("config"), EvalTask)
|
||||||
|
else EvalTask.model_validate(task_data.get("config"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_tasks(self) -> List[EvalTask]:
|
||||||
|
"""
|
||||||
|
List all available evaluation tasks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of evaluation tasks
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(EvalTaskDB)
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
if response.status and response.data:
|
||||||
|
for task_data in response.data:
|
||||||
|
config = task_data.get("config")
|
||||||
|
if config:
|
||||||
|
if isinstance(config, EvalTask):
|
||||||
|
tasks.append(config)
|
||||||
|
else:
|
||||||
|
tasks.append(EvalTask.model_validate(config))
|
||||||
|
return tasks
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return list(self._tasks.values())
|
||||||
|
|
||||||
|
# ----- Criteria Management -----
|
||||||
|
|
||||||
|
async def create_criteria(self, criteria: EvalJudgeCriteria) -> str:
|
||||||
|
"""
|
||||||
|
Create new evaluation criteria.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
criteria: The evaluation criteria to create
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Criteria ID
|
||||||
|
"""
|
||||||
|
criteria_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if self._db_manager:
|
||||||
|
# Store in database
|
||||||
|
criteria_db = EvalCriteriaDB(name=criteria.dimension, description=criteria.prompt, config=criteria)
|
||||||
|
response = self._db_manager.upsert(criteria_db)
|
||||||
|
if not response.status:
|
||||||
|
logger.error(f"Failed to store criteria: {response.message}")
|
||||||
|
raise RuntimeError(f"Failed to store criteria: {response.message}")
|
||||||
|
criteria_id = str(response.data.get("id")) if response.data else criteria_id
|
||||||
|
else:
|
||||||
|
# Store in memory
|
||||||
|
self._criteria[criteria_id] = criteria
|
||||||
|
|
||||||
|
return criteria_id
|
||||||
|
|
||||||
|
async def get_criteria(self, criteria_id: str) -> Optional[EvalJudgeCriteria]:
|
||||||
|
"""
|
||||||
|
Retrieve evaluation criteria by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
criteria_id: The ID of the criteria to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The criteria if found, None otherwise
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(
|
||||||
|
EvalCriteriaDB, filters={"id": int(criteria_id) if criteria_id.isdigit() else criteria_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
criteria_data = response.data[0]
|
||||||
|
return (
|
||||||
|
criteria_data.get("config")
|
||||||
|
if isinstance(criteria_data.get("config"), EvalJudgeCriteria)
|
||||||
|
else EvalJudgeCriteria.model_validate(criteria_data.get("config"))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return self._criteria.get(criteria_id)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def list_criteria(self) -> List[EvalJudgeCriteria]:
|
||||||
|
"""
|
||||||
|
List all available evaluation criteria.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of evaluation criteria
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(EvalCriteriaDB)
|
||||||
|
|
||||||
|
criteria_list = []
|
||||||
|
if response.status and response.data:
|
||||||
|
for criteria_data in response.data:
|
||||||
|
config = criteria_data.get("config")
|
||||||
|
if config:
|
||||||
|
if isinstance(config, EvalJudgeCriteria):
|
||||||
|
criteria_list.append(config)
|
||||||
|
else:
|
||||||
|
criteria_list.append(EvalJudgeCriteria.model_validate(config))
|
||||||
|
return criteria_list
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return list(self._criteria.values())
|
||||||
|
|
||||||
|
# ----- Run Management -----
|
||||||
|
|
||||||
|
async def create_run(
|
||||||
|
self,
|
||||||
|
task: Union[str, EvalTask],
|
||||||
|
runner: BaseEvalRunner,
|
||||||
|
judge: BaseEvalJudge,
|
||||||
|
criteria: List[Union[str, EvalJudgeCriteria]],
|
||||||
|
name: str = "",
|
||||||
|
description: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create a new evaluation run configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task to evaluate (ID or task object)
|
||||||
|
runner: The runner to use for evaluation
|
||||||
|
judge: The judge to use for evaluation
|
||||||
|
criteria: List of criteria to use for evaluation (IDs or criteria objects)
|
||||||
|
name: Name for the run
|
||||||
|
description: Description for the run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Run ID
|
||||||
|
"""
|
||||||
|
# Resolve task
|
||||||
|
task_obj = None
|
||||||
|
if isinstance(task, str):
|
||||||
|
task_obj = await self.get_task(task)
|
||||||
|
if not task_obj:
|
||||||
|
raise ValueError(f"Task not found: {task}")
|
||||||
|
else:
|
||||||
|
task_obj = task
|
||||||
|
|
||||||
|
# Resolve criteria
|
||||||
|
criteria_objs = []
|
||||||
|
for criterion in criteria:
|
||||||
|
if isinstance(criterion, str):
|
||||||
|
criterion_obj = await self.get_criteria(criterion)
|
||||||
|
if not criterion_obj:
|
||||||
|
raise ValueError(f"Criteria not found: {criterion}")
|
||||||
|
criteria_objs.append(criterion_obj)
|
||||||
|
else:
|
||||||
|
criteria_objs.append(criterion)
|
||||||
|
|
||||||
|
# Generate run ID
|
||||||
|
run_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create run configuration
|
||||||
|
runner_config = runner.dump_component() if hasattr(runner, "dump_component") else runner._to_config()
|
||||||
|
judge_config = judge.dump_component() if hasattr(judge, "dump_component") else judge._to_config()
|
||||||
|
|
||||||
|
if self._db_manager:
|
||||||
|
# Store in database
|
||||||
|
run_db = EvalRunDB(
|
||||||
|
name=name or f"Run {run_id}",
|
||||||
|
description=description,
|
||||||
|
task_id=int(task) if isinstance(task, str) and task.isdigit() else None,
|
||||||
|
runner_config=runner_config.model_dump(),
|
||||||
|
judge_config=judge_config.model_dump(),
|
||||||
|
criteria_configs=criteria_objs,
|
||||||
|
status=EvalRunStatus.PENDING,
|
||||||
|
)
|
||||||
|
response = self._db_manager.upsert(run_db)
|
||||||
|
if not response.status:
|
||||||
|
logger.error(f"Failed to store run: {response.message}")
|
||||||
|
raise RuntimeError(f"Failed to store run: {response.message}")
|
||||||
|
run_id = str(response.data.get("id")) if response.data else run_id
|
||||||
|
else:
|
||||||
|
# Store in memory
|
||||||
|
self._runs[run_id] = {
|
||||||
|
"task": task_obj,
|
||||||
|
"runner_config": runner_config,
|
||||||
|
"judge_config": judge_config,
|
||||||
|
"criteria_configs": [c.model_dump() for c in criteria_objs],
|
||||||
|
"status": EvalRunStatus.PENDING,
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
"run_result": None,
|
||||||
|
"score_result": None,
|
||||||
|
"name": name or f"Run {run_id}",
|
||||||
|
"description": description,
|
||||||
|
}
|
||||||
|
|
||||||
|
return run_id
|
||||||
|
|
||||||
|
async def start_run(self, run_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Start an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run to start
|
||||||
|
"""
|
||||||
|
# Check if run is already active
|
||||||
|
if run_id in self._active_runs:
|
||||||
|
logger.warning(f"Run {run_id} is already active")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start the run asynchronously
|
||||||
|
run_task = asyncio.create_task(self._execute_run(run_id))
|
||||||
|
self._active_runs[run_id] = run_task
|
||||||
|
|
||||||
|
# Update run status
|
||||||
|
await self._update_run_status(run_id, EvalRunStatus.RUNNING)
|
||||||
|
|
||||||
|
async def _execute_run(self, run_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Execute an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run to execute
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get run configuration
|
||||||
|
run_config = await self._get_run_config(run_id)
|
||||||
|
if not run_config:
|
||||||
|
raise ValueError(f"Run not found: {run_id}")
|
||||||
|
|
||||||
|
# Get task
|
||||||
|
task = run_config.get("task")
|
||||||
|
if not task:
|
||||||
|
raise ValueError(f"Task not found for run: {run_id}")
|
||||||
|
|
||||||
|
# Initialize runner
|
||||||
|
runner_config = run_config.get("runner_config")
|
||||||
|
runner = BaseEvalRunner.load_component(runner_config) if runner_config else None
|
||||||
|
|
||||||
|
# Initialize judge
|
||||||
|
judge_config = run_config.get("judge_config")
|
||||||
|
judge = BaseEvalJudge.load_component(judge_config) if judge_config else None
|
||||||
|
|
||||||
|
if not runner or not judge:
|
||||||
|
raise ValueError(f"Runner or judge not found for run: {run_id}")
|
||||||
|
|
||||||
|
# Initialize criteria
|
||||||
|
criteria_configs = run_config.get("criteria_configs")
|
||||||
|
criteria = []
|
||||||
|
if criteria_configs:
|
||||||
|
criteria = [
|
||||||
|
EvalJudgeCriteria.model_validate(c) if not isinstance(c, EvalJudgeCriteria) else c
|
||||||
|
for c in criteria_configs
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute runner
|
||||||
|
logger.info(f"Starting runner for run {run_id}")
|
||||||
|
start_time = datetime.now()
|
||||||
|
run_result = await runner.run(task)
|
||||||
|
|
||||||
|
# Update run result
|
||||||
|
await self._update_run_result(run_id, run_result)
|
||||||
|
|
||||||
|
if not run_result.status:
|
||||||
|
logger.error(f"Runner failed for run {run_id}: {run_result.error}")
|
||||||
|
await self._update_run_status(run_id, EvalRunStatus.FAILED)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute judge
|
||||||
|
logger.info(f"Starting judge for run {run_id}")
|
||||||
|
score_result = await judge.judge(task, run_result, criteria)
|
||||||
|
|
||||||
|
# Update score result
|
||||||
|
await self._update_score_result(run_id, score_result)
|
||||||
|
|
||||||
|
# Update run status
|
||||||
|
end_time = datetime.now()
|
||||||
|
await self._update_run_completed(run_id, start_time, end_time)
|
||||||
|
|
||||||
|
logger.info(f"Run {run_id} completed successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error executing run {run_id}: {str(e)}")
|
||||||
|
await self._update_run_error(run_id, str(e))
|
||||||
|
finally:
|
||||||
|
# Remove from active runs
|
||||||
|
if run_id in self._active_runs:
|
||||||
|
del self._active_runs[run_id]
|
||||||
|
|
||||||
|
async def get_run_status(self, run_id: str) -> Optional[EvalRunStatus]:
|
||||||
|
"""
|
||||||
|
Get the status of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run status if found, None otherwise
|
||||||
|
"""
|
||||||
|
run_config = await self._get_run_config(run_id)
|
||||||
|
return run_config.get("status") if run_config else None
|
||||||
|
|
||||||
|
async def get_run_result(self, run_id: str) -> Optional[EvalRunResult]:
|
||||||
|
"""
|
||||||
|
Get the result of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run result if found, None otherwise
|
||||||
|
"""
|
||||||
|
run_config = await self._get_run_config(run_id)
|
||||||
|
if not run_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
run_result = run_config.get("run_result")
|
||||||
|
if not run_result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return run_result if isinstance(run_result, EvalRunResult) else EvalRunResult.model_validate(run_result)
|
||||||
|
|
||||||
|
async def get_run_score(self, run_id: str) -> Optional[EvalScore]:
|
||||||
|
"""
|
||||||
|
Get the score of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run score if found, None otherwise
|
||||||
|
"""
|
||||||
|
run_config = await self._get_run_config(run_id)
|
||||||
|
if not run_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
score_result = run_config.get("score_result")
|
||||||
|
if not score_result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return score_result if isinstance(score_result, EvalScore) else EvalScore.model_validate(score_result)
|
||||||
|
|
||||||
|
async def list_runs(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all available evaluation runs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of run configurations
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(EvalRunDB)
|
||||||
|
|
||||||
|
runs = []
|
||||||
|
if response.status and response.data:
|
||||||
|
for run_data in response.data:
|
||||||
|
runs.append(
|
||||||
|
{
|
||||||
|
"id": run_data.get("id"),
|
||||||
|
"name": run_data.get("name"),
|
||||||
|
"status": run_data.get("status"),
|
||||||
|
"created_at": run_data.get("created_at"),
|
||||||
|
"updated_at": run_data.get("updated_at"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return runs
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": run_id,
|
||||||
|
"name": run_config.get("name"),
|
||||||
|
"status": run_config.get("status"),
|
||||||
|
"created_at": run_config.get("created_at"),
|
||||||
|
"updated_at": run_config.get("updated_at", run_config.get("created_at")),
|
||||||
|
}
|
||||||
|
for run_id, run_config in self._runs.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
async def cancel_run(self, run_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Cancel an active evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run to cancel
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the run was cancelled, False otherwise
|
||||||
|
"""
|
||||||
|
# Check if run is active
|
||||||
|
if run_id not in self._active_runs:
|
||||||
|
logger.warning(f"Run {run_id} is not active")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Cancel the run task
|
||||||
|
try:
|
||||||
|
self._active_runs[run_id].cancel()
|
||||||
|
await self._update_run_status(run_id, EvalRunStatus.CANCELED)
|
||||||
|
del self._active_runs[run_id]
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ----- Helper Methods -----
|
||||||
|
|
||||||
|
async def _get_run_config(self, run_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get the configuration of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The run configuration if found, None otherwise
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Retrieve from database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
|
||||||
|
# Get task
|
||||||
|
task = None
|
||||||
|
if run_data.get("task_id"):
|
||||||
|
task_response = self._db_manager.get(EvalTaskDB, filters={"id": run_data.get("task_id")})
|
||||||
|
if task_response.status and task_response.data and len(task_response.data) > 0:
|
||||||
|
task_data = task_response.data[0]
|
||||||
|
task = (
|
||||||
|
task_data.get("config")
|
||||||
|
if isinstance(task_data.get("config"), EvalTask)
|
||||||
|
else EvalTask.model_validate(task_data.get("config"))
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task": task,
|
||||||
|
"runner_config": run_data.get("runner_config"),
|
||||||
|
"judge_config": run_data.get("judge_config"),
|
||||||
|
"criteria_configs": run_data.get("criteria_configs"),
|
||||||
|
"status": run_data.get("status"),
|
||||||
|
"run_result": run_data.get("run_result"),
|
||||||
|
"score_result": run_data.get("score_result"),
|
||||||
|
"name": run_data.get("name"),
|
||||||
|
"description": run_data.get("description"),
|
||||||
|
"created_at": run_data.get("created_at"),
|
||||||
|
"updated_at": run_data.get("updated_at"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Retrieve from memory
|
||||||
|
return self._runs.get(run_id)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _update_run_status(self, run_id: str, status: EvalRunStatus) -> None:
|
||||||
|
"""
|
||||||
|
Update the status of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
status: The new status
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Update in database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
run_db = EvalRunDB.model_validate(run_data)
|
||||||
|
run_db.status = status
|
||||||
|
run_db.updated_at = datetime.now()
|
||||||
|
self._db_manager.upsert(run_db)
|
||||||
|
else:
|
||||||
|
# Update in memory
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["status"] = status
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now()
|
||||||
|
|
||||||
|
async def _update_run_result(self, run_id: str, run_result: EvalRunResult) -> None:
|
||||||
|
"""
|
||||||
|
Update the result of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
run_result: The run result
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Update in database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
run_db = EvalRunDB.model_validate(run_data)
|
||||||
|
run_db.run_result = run_result
|
||||||
|
run_db.updated_at = datetime.now()
|
||||||
|
self._db_manager.upsert(run_db)
|
||||||
|
else:
|
||||||
|
# Update in memory
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["run_result"] = run_result
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now()
|
||||||
|
|
||||||
|
async def _update_score_result(self, run_id: str, score_result: EvalScore) -> None:
|
||||||
|
"""
|
||||||
|
Update the score of an evaluation run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
score_result: The score result
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Update in database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
run_db = EvalRunDB.model_validate(run_data)
|
||||||
|
run_db.score_result = score_result
|
||||||
|
run_db.updated_at = datetime.now()
|
||||||
|
self._db_manager.upsert(run_db)
|
||||||
|
else:
|
||||||
|
# Update in memory
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["score_result"] = score_result
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now()
|
||||||
|
|
||||||
|
async def _update_run_completed(self, run_id: str, start_time: datetime, end_time: datetime) -> None:
|
||||||
|
"""
|
||||||
|
Update a run as completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
start_time: The start time
|
||||||
|
end_time: The end time
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Update in database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
run_db = EvalRunDB.model_validate(run_data)
|
||||||
|
run_db.status = EvalRunStatus.COMPLETED
|
||||||
|
run_db.start_time = start_time
|
||||||
|
run_db.end_time = end_time
|
||||||
|
run_db.updated_at = datetime.now()
|
||||||
|
self._db_manager.upsert(run_db)
|
||||||
|
else:
|
||||||
|
# Update in memory
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["status"] = EvalRunStatus.COMPLETED
|
||||||
|
self._runs[run_id]["start_time"] = start_time
|
||||||
|
self._runs[run_id]["end_time"] = end_time
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now()
|
||||||
|
|
||||||
|
async def _update_run_error(self, run_id: str, error_message: str) -> None:
|
||||||
|
"""
|
||||||
|
Update a run with an error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run
|
||||||
|
error_message: The error message
|
||||||
|
"""
|
||||||
|
if self._db_manager:
|
||||||
|
# Update in database
|
||||||
|
response = self._db_manager.get(EvalRunDB, filters={"id": int(run_id) if run_id.isdigit() else run_id})
|
||||||
|
|
||||||
|
if response.status and response.data and len(response.data) > 0:
|
||||||
|
run_data = response.data[0]
|
||||||
|
run_db = EvalRunDB.model_validate(run_data)
|
||||||
|
run_db.status = EvalRunStatus.FAILED
|
||||||
|
run_db.error_message = error_message
|
||||||
|
run_db.end_time = datetime.now()
|
||||||
|
run_db.updated_at = datetime.now()
|
||||||
|
self._db_manager.upsert(run_db)
|
||||||
|
else:
|
||||||
|
# Update in memory
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["status"] = EvalRunStatus.FAILED
|
||||||
|
self._runs[run_id]["error_message"] = error_message
|
||||||
|
self._runs[run_id]["end_time"] = datetime.now()
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now()
|
||||||
|
|
||||||
|
async def tabulate_results(self, run_ids: List[str], include_reasons: bool = False) -> TabulatedResults:
|
||||||
|
"""
|
||||||
|
Generate a tabular representation of evaluation results across runs.
|
||||||
|
|
||||||
|
This method collects scores across different runs and organizes them by
|
||||||
|
dimension, making it easy to create visualizations like radar charts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_ids: List of run IDs to include in the tabulation
|
||||||
|
include_reasons: Whether to include scoring reasons in the output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with structured data suitable for visualization
|
||||||
|
"""
|
||||||
|
result: TabulatedResults = {"dimensions": [], "runs": []}
|
||||||
|
|
||||||
|
# Parallelize fetching of run configs and scores
|
||||||
|
fetch_tasks = []
|
||||||
|
for run_id in run_ids:
|
||||||
|
fetch_tasks.append(self._get_run_config(run_id))
|
||||||
|
fetch_tasks.append(self.get_run_score(run_id))
|
||||||
|
|
||||||
|
# Wait for all fetches to complete
|
||||||
|
fetch_results = await asyncio.gather(*fetch_tasks)
|
||||||
|
|
||||||
|
# Process fetched data
|
||||||
|
dimensions_set = set()
|
||||||
|
run_data = {}
|
||||||
|
|
||||||
|
for i in range(0, len(fetch_results), 2):
|
||||||
|
run_id = run_ids[i // 2]
|
||||||
|
run_config = fetch_results[i]
|
||||||
|
score = fetch_results[i + 1]
|
||||||
|
|
||||||
|
# Store run data for later processing
|
||||||
|
run_data[run_id] = (run_config, score)
|
||||||
|
|
||||||
|
# Collect dimensions
|
||||||
|
if score and score.dimension_scores:
|
||||||
|
for dim_score in score.dimension_scores:
|
||||||
|
dimensions_set.add(dim_score.dimension)
|
||||||
|
|
||||||
|
# Convert dimensions to sorted list
|
||||||
|
result["dimensions"] = sorted(list(dimensions_set))
|
||||||
|
|
||||||
|
# Process each run's data
|
||||||
|
for run_id, (run_config, score) in run_data.items():
|
||||||
|
if not run_config or not score:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Determine runner type
|
||||||
|
runner_type = "unknown"
|
||||||
|
if run_config.get("runner_config"):
|
||||||
|
runner_config = run_config.get("runner_config")
|
||||||
|
if runner_config is not None and "provider" in runner_config:
|
||||||
|
if "ModelEvalRunner" in runner_config["provider"]:
|
||||||
|
runner_type = "model"
|
||||||
|
elif "TeamEvalRunner" in runner_config["provider"]:
|
||||||
|
runner_type = "team"
|
||||||
|
|
||||||
|
# Get task name
|
||||||
|
task = run_config.get("task")
|
||||||
|
task_name = task.name if task else "Unknown Task"
|
||||||
|
|
||||||
|
# Create run entry
|
||||||
|
run_entry: RunEntry = {
|
||||||
|
"id": run_id,
|
||||||
|
"name": run_config.get("name", f"Run {run_id}"),
|
||||||
|
"task_name": task_name,
|
||||||
|
"runner_type": runner_type,
|
||||||
|
"overall_score": score.overall_score,
|
||||||
|
"scores": [],
|
||||||
|
"reasons": [] if include_reasons else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build dimension lookup map for O(1) access
|
||||||
|
dim_map = {ds.dimension: ds for ds in score.dimension_scores}
|
||||||
|
|
||||||
|
# Populate scores aligned with dimensions
|
||||||
|
for dim in result["dimensions"]:
|
||||||
|
dim_score = dim_map.get(dim)
|
||||||
|
if dim_score:
|
||||||
|
run_entry["scores"].append(dim_score.score)
|
||||||
|
if include_reasons:
|
||||||
|
run_entry["reasons"].append(dim_score.reason) # type: ignore
|
||||||
|
else:
|
||||||
|
run_entry["scores"].append(None)
|
||||||
|
if include_reasons:
|
||||||
|
run_entry["reasons"].append(None) # type: ignore
|
||||||
|
|
||||||
|
result["runs"].append(run_entry)
|
||||||
|
|
||||||
|
return result
|
|
@ -0,0 +1,201 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||||
|
|
||||||
|
from autogen_agentchat.base import TaskResult, Team
|
||||||
|
from autogen_agentchat.messages import ChatMessage, MultiModalMessage, TextMessage
|
||||||
|
from autogen_core import CancellationToken, Component, ComponentBase, ComponentModel, Image
|
||||||
|
from autogen_core.models import ChatCompletionClient, UserMessage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ..datamodel.eval import EvalRunResult, EvalTask
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvalRunnerConfig(BaseModel):
|
||||||
|
"""Base configuration for evaluation runners."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEvalRunner(ABC, ComponentBase[BaseEvalRunnerConfig]):
|
||||||
|
"""Base class for evaluation runners that defines the interface for running evaluations.
|
||||||
|
|
||||||
|
This class provides the core interface that all evaluation runners must implement.
|
||||||
|
Subclasses should implement the run method to define how a specific evaluation is executed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
component_type = "eval_runner"
|
||||||
|
|
||||||
|
def __init__(self, name: str, description: str = "", metadata: Optional[Dict[str, Any]] = None):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.metadata = metadata or {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
|
||||||
|
"""Run the evaluation on the provided task and return a result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task to evaluate
|
||||||
|
cancellation_token: Optional token to cancel the evaluation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EvaluationResult: The result of the evaluation
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _to_config(self) -> BaseEvalRunnerConfig:
|
||||||
|
"""Convert the runner configuration to a configuration object for serialization."""
|
||||||
|
return BaseEvalRunnerConfig(name=self.name, description=self.description, metadata=self.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvalRunnerConfig(BaseEvalRunnerConfig):
|
||||||
|
"""Configuration for ModelEvalRunner."""
|
||||||
|
|
||||||
|
model_client: ComponentModel
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEvalRunner(BaseEvalRunner, Component[ModelEvalRunnerConfig]):
|
||||||
|
"""Evaluation runner that uses a single LLM to process tasks.
|
||||||
|
|
||||||
|
This runner sends the task directly to a model client and returns the response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
component_config_schema = ModelEvalRunnerConfig
|
||||||
|
component_type = "eval_runner"
|
||||||
|
component_provider_override = "autogenstudio.eval.runners.ModelEvalRunner"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_client: ChatCompletionClient,
|
||||||
|
name: str = "Model Runner",
|
||||||
|
description: str = "Evaluates tasks using a single LLM",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, description, metadata)
|
||||||
|
self.model_client = model_client
|
||||||
|
|
||||||
|
async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
|
||||||
|
"""Run the task with the model client and return the result."""
|
||||||
|
# Create initial result object
|
||||||
|
result = EvalRunResult()
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_input = []
|
||||||
|
if isinstance(task.input, str):
|
||||||
|
text_message = UserMessage(content=task.input, source="user")
|
||||||
|
model_input.append(text_message)
|
||||||
|
elif isinstance(task.input, list):
|
||||||
|
message_content = [x for x in task.input]
|
||||||
|
model_input.append(UserMessage(content=message_content, source="user"))
|
||||||
|
# Run with the model
|
||||||
|
model_result = await self.model_client.create(messages=model_input, cancellation_token=cancellation_token)
|
||||||
|
|
||||||
|
model_response = model_result.content if isinstance(model_result, str) else model_result.model_dump()
|
||||||
|
|
||||||
|
task_result = TaskResult(
|
||||||
|
messages=[TextMessage(content=str(model_response), source="model")],
|
||||||
|
)
|
||||||
|
result = EvalRunResult(result=task_result, status=True, start_time=datetime.now(), end_time=datetime.now())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result = EvalRunResult(status=False, error=str(e), end_time=datetime.now())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _to_config(self) -> ModelEvalRunnerConfig:
|
||||||
|
"""Convert to configuration object including model client configuration."""
|
||||||
|
base_config = super()._to_config()
|
||||||
|
return ModelEvalRunnerConfig(
|
||||||
|
name=base_config.name,
|
||||||
|
description=base_config.description,
|
||||||
|
metadata=base_config.metadata,
|
||||||
|
model_client=self.model_client.dump_component(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config: ModelEvalRunnerConfig) -> Self:
|
||||||
|
"""Create from configuration object with serialized model client."""
|
||||||
|
model_client = ChatCompletionClient.load_component(config.model_client)
|
||||||
|
return cls(
|
||||||
|
name=config.name,
|
||||||
|
description=config.description,
|
||||||
|
metadata=config.metadata,
|
||||||
|
model_client=model_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TeamEvalRunnerConfig(BaseEvalRunnerConfig):
|
||||||
|
"""Configuration for TeamEvalRunner."""
|
||||||
|
|
||||||
|
team: ComponentModel
|
||||||
|
|
||||||
|
|
||||||
|
class TeamEvalRunner(BaseEvalRunner, Component[TeamEvalRunnerConfig]):
|
||||||
|
"""Evaluation runner that uses a team of agents to process tasks.
|
||||||
|
|
||||||
|
This runner creates and runs a team based on a team configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
component_config_schema = TeamEvalRunnerConfig
|
||||||
|
component_type = "eval_runner"
|
||||||
|
component_provider_override = "autogenstudio.eval.runners.TeamEvalRunner"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
team: Union[Team, ComponentModel],
|
||||||
|
name: str = "Team Runner",
|
||||||
|
description: str = "Evaluates tasks using a team of agents",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
super().__init__(name, description, metadata)
|
||||||
|
self._team = team if isinstance(team, Team) else Team.load_component(team)
|
||||||
|
|
||||||
|
async def run(self, task: EvalTask, cancellation_token: Optional[CancellationToken] = None) -> EvalRunResult:
|
||||||
|
"""Run the task with the team and return the result."""
|
||||||
|
# Create initial result object
|
||||||
|
result = EvalRunResult()
|
||||||
|
|
||||||
|
try:
|
||||||
|
team_task: Sequence[ChatMessage] = []
|
||||||
|
if isinstance(task.input, str):
|
||||||
|
team_task.append(TextMessage(content=task.input, source="user"))
|
||||||
|
if isinstance(task.input, list):
|
||||||
|
for message in task.input:
|
||||||
|
if isinstance(message, str):
|
||||||
|
team_task.append(TextMessage(content=message, source="user"))
|
||||||
|
elif isinstance(message, Image):
|
||||||
|
team_task.append(MultiModalMessage(source="user", content=[message]))
|
||||||
|
|
||||||
|
# Run task with team
|
||||||
|
team_result = await self._team.run(task=team_task, cancellation_token=cancellation_token)
|
||||||
|
|
||||||
|
result = EvalRunResult(result=team_result, status=True, start_time=datetime.now(), end_time=datetime.now())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result = EvalRunResult(status=False, error=str(e), end_time=datetime.now())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _to_config(self) -> TeamEvalRunnerConfig:
|
||||||
|
"""Convert to configuration object including team configuration."""
|
||||||
|
base_config = super()._to_config()
|
||||||
|
return TeamEvalRunnerConfig(
|
||||||
|
name=base_config.name,
|
||||||
|
description=base_config.description,
|
||||||
|
metadata=base_config.metadata,
|
||||||
|
team=self._team.dump_component(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_config(cls, config: TeamEvalRunnerConfig) -> Self:
|
||||||
|
"""Create from configuration object with serialized team configuration."""
|
||||||
|
return cls(
|
||||||
|
team=Team.load_component(config.team),
|
||||||
|
name=config.name,
|
||||||
|
description=config.description,
|
||||||
|
metadata=config.metadata,
|
||||||
|
)
|
|
@ -4,18 +4,19 @@ import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncGenerator, Callable, List, Optional, Union
|
from typing import AsyncGenerator, Callable, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import yaml
|
import yaml
|
||||||
from autogen_agentchat.agents import UserProxyAgent
|
from autogen_agentchat.agents import UserProxyAgent
|
||||||
from autogen_agentchat.base import TaskResult, Team
|
from autogen_agentchat.base import TaskResult
|
||||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
|
||||||
from autogen_agentchat.teams import BaseGroupChat
|
from autogen_agentchat.teams import BaseGroupChat
|
||||||
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, Component, ComponentModel
|
from autogen_core import EVENT_LOGGER_NAME, CancellationToken, ComponentModel
|
||||||
from autogen_core.logging import LLMCallEvent
|
from autogen_core.logging import LLMCallEvent
|
||||||
|
|
||||||
from ..datamodel.types import EnvironmentVariable, LLMCallEventMessage, TeamResult
|
from ..datamodel.types import EnvironmentVariable, LLMCallEventMessage, TeamResult
|
||||||
|
from ..web.managers.run_context import RunContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -35,6 +36,10 @@ class RunEventLogger(logging.Handler):
|
||||||
class TeamManager:
|
class TeamManager:
|
||||||
"""Manages team operations including loading configs and running teams"""
|
"""Manages team operations including loading configs and running teams"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._team: Optional[BaseGroupChat] = None
|
||||||
|
self._run_context = RunContext()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def load_from_file(path: Union[str, Path]) -> dict:
|
async def load_from_file(path: Union[str, Path]) -> dict:
|
||||||
"""Load team configuration from JSON/YAML file"""
|
"""Load team configuration from JSON/YAML file"""
|
||||||
|
@ -87,17 +92,17 @@ class TeamManager:
|
||||||
for var in env_vars:
|
for var in env_vars:
|
||||||
os.environ[var.name] = var.value
|
os.environ[var.name] = var.value
|
||||||
|
|
||||||
team: BaseGroupChat = BaseGroupChat.load_component(config)
|
self._team = BaseGroupChat.load_component(config)
|
||||||
|
|
||||||
for agent in team._participants:
|
for agent in self._team._participants:
|
||||||
if hasattr(agent, "input_func") and isinstance(agent, UserProxyAgent) and input_func:
|
if hasattr(agent, "input_func") and isinstance(agent, UserProxyAgent) and input_func:
|
||||||
agent.input_func = input_func
|
agent.input_func = input_func
|
||||||
|
|
||||||
return team
|
return self._team
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None,
|
||||||
team_config: Union[str, Path, dict, ComponentModel],
|
team_config: Union[str, Path, dict, ComponentModel],
|
||||||
input_func: Optional[Callable] = None,
|
input_func: Optional[Callable] = None,
|
||||||
cancellation_token: Optional[CancellationToken] = None,
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
@ -142,7 +147,7 @@ class TeamManager:
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str | BaseChatMessage | Sequence[BaseChatMessage] | None,
|
||||||
team_config: Union[str, Path, dict, ComponentModel],
|
team_config: Union[str, Path, dict, ComponentModel],
|
||||||
input_func: Optional[Callable] = None,
|
input_func: Optional[Callable] = None,
|
||||||
cancellation_token: Optional[CancellationToken] = None,
|
cancellation_token: Optional[CancellationToken] = None,
|
||||||
|
|
|
@ -1,262 +1,71 @@
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
from typing import Sequence
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from autogen_agentchat.messages import ChatMessage, MultiModalMessage, TextMessage
|
||||||
|
from autogen_core import Image
|
||||||
|
from autogen_core.models import UserMessage
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..version import APP_NAME
|
|
||||||
|
|
||||||
|
def construct_task(query: str, files: list[dict] | None = None) -> Sequence[ChatMessage]:
|
||||||
def sha256_hash(text: str) -> str:
|
|
||||||
"""
|
"""
|
||||||
Compute the SHA-256 hash of a given text.
|
Construct a task from a query string and list of files.
|
||||||
|
Returns a list of ChatMessage objects suitable for processing by the agent system.
|
||||||
|
|
||||||
:param text: The string to hash
|
Args:
|
||||||
:return: The SHA-256 hash of the text, hex-encoded.
|
query: The text query from the user
|
||||||
|
files: List of file objects with properties name, content, and type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of BaseChatMessage objects (TextMessage, MultiModalMessage)
|
||||||
"""
|
"""
|
||||||
return hashlib.sha256(text.encode()).hexdigest()
|
if files is None:
|
||||||
|
files = []
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
def check_and_cast_datetime_fields(obj: Any) -> Any:
|
# Add the user's text query as a TextMessage
|
||||||
if hasattr(obj, "created_at") and isinstance(obj.created_at, str):
|
if query:
|
||||||
obj.created_at = str_to_datetime(obj.created_at)
|
messages.append(TextMessage(source="user", content=query))
|
||||||
|
|
||||||
if hasattr(obj, "updated_at") and isinstance(obj.updated_at, str):
|
|
||||||
obj.updated_at = str_to_datetime(obj.updated_at)
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def str_to_datetime(dt_str: str) -> datetime:
|
|
||||||
if dt_str[-1] == "Z":
|
|
||||||
# Replace 'Z' with '+00:00' for UTC timezone
|
|
||||||
dt_str = dt_str[:-1] + "+00:00"
|
|
||||||
return datetime.fromisoformat(dt_str)
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_type(file_path: str) -> str:
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
Get file type determined by the file extension. If the file extension is not
|
|
||||||
recognized, 'unknown' will be used as the file type.
|
|
||||||
|
|
||||||
:param file_path: The path to the file to be serialized.
|
|
||||||
:return: A string containing the file type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Extended list of file extensions for code and text files
|
|
||||||
CODE_EXTENSIONS = {
|
|
||||||
".py",
|
|
||||||
".js",
|
|
||||||
".jsx",
|
|
||||||
".java",
|
|
||||||
".c",
|
|
||||||
".cpp",
|
|
||||||
".cs",
|
|
||||||
".ts",
|
|
||||||
".tsx",
|
|
||||||
".html",
|
|
||||||
".css",
|
|
||||||
".scss",
|
|
||||||
".less",
|
|
||||||
".json",
|
|
||||||
".xml",
|
|
||||||
".yaml",
|
|
||||||
".yml",
|
|
||||||
".md",
|
|
||||||
".rst",
|
|
||||||
".tex",
|
|
||||||
".sh",
|
|
||||||
".bat",
|
|
||||||
".ps1",
|
|
||||||
".php",
|
|
||||||
".rb",
|
|
||||||
".go",
|
|
||||||
".swift",
|
|
||||||
".kt",
|
|
||||||
".hs",
|
|
||||||
".scala",
|
|
||||||
".lua",
|
|
||||||
".pl",
|
|
||||||
".sql",
|
|
||||||
".config",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Supported spreadsheet extensions
|
|
||||||
CSV_EXTENSIONS = {".csv", ".xlsx"}
|
|
||||||
|
|
||||||
# Supported image extensions
|
|
||||||
IMAGE_EXTENSIONS = {
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".gif",
|
|
||||||
".bmp",
|
|
||||||
".tiff",
|
|
||||||
".svg",
|
|
||||||
".webp",
|
|
||||||
}
|
|
||||||
# Supported (web) video extensions
|
|
||||||
VIDEO_EXTENSIONS = {".mp4", ".webm", ".ogg", ".mov", ".avi", ".wmv"}
|
|
||||||
|
|
||||||
# Supported PDF extension
|
|
||||||
PDF_EXTENSION = ".pdf"
|
|
||||||
|
|
||||||
# Determine the file extension
|
|
||||||
_, file_extension = os.path.splitext(file_path)
|
|
||||||
|
|
||||||
# Determine the file type based on the extension
|
|
||||||
if file_extension in CODE_EXTENSIONS:
|
|
||||||
file_type = "code"
|
|
||||||
elif file_extension in CSV_EXTENSIONS:
|
|
||||||
file_type = "csv"
|
|
||||||
elif file_extension in IMAGE_EXTENSIONS:
|
|
||||||
file_type = "image"
|
|
||||||
elif file_extension == PDF_EXTENSION:
|
|
||||||
file_type = "pdf"
|
|
||||||
elif file_extension in VIDEO_EXTENSIONS:
|
|
||||||
file_type = "video"
|
|
||||||
else:
|
|
||||||
file_type = "unknown"
|
|
||||||
|
|
||||||
return file_type
|
|
||||||
|
|
||||||
|
|
||||||
def get_modified_files(start_timestamp: float, end_timestamp: float, source_dir: str) -> List[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Identify files from source_dir that were modified within a specified timestamp range.
|
|
||||||
The function excludes files with certain file extensions and names.
|
|
||||||
|
|
||||||
:param start_timestamp: The floating-point number representing the start timestamp to filter modified files.
|
|
||||||
:param end_timestamp: The floating-point number representing the end timestamp to filter modified files.
|
|
||||||
:param source_dir: The directory to search for modified files.
|
|
||||||
|
|
||||||
:return: A list of dictionaries with details of relative file paths that were modified.
|
|
||||||
Dictionary format: {path: "", name: "", extension: "", type: ""}
|
|
||||||
Files with extensions "__pycache__", "*.pyc", "__init__.py", and "*.cache"
|
|
||||||
are ignored.
|
|
||||||
"""
|
|
||||||
modified_files = []
|
|
||||||
ignore_extensions = {".pyc", ".cache"}
|
|
||||||
ignore_files = {"__pycache__", "__init__.py"}
|
|
||||||
|
|
||||||
# Walk through the directory tree
|
|
||||||
for root, dirs, files in os.walk(source_dir):
|
|
||||||
# Update directories and files to exclude those to be ignored
|
|
||||||
dirs[:] = [d for d in dirs if d not in ignore_files]
|
|
||||||
files[:] = [f for f in files if f not in ignore_files and os.path.splitext(f)[1] not in ignore_extensions]
|
|
||||||
|
|
||||||
|
# Process each file based on its type
|
||||||
for file in files:
|
for file in files:
|
||||||
file_path = os.path.join(root, file)
|
|
||||||
file_mtime = os.path.getmtime(file_path)
|
|
||||||
|
|
||||||
# Verify if the file was modified within the given timestamp range
|
|
||||||
if start_timestamp <= file_mtime <= end_timestamp:
|
|
||||||
file_relative_path = (
|
|
||||||
"files/user" + file_path.split("files/user", 1)[1] if "files/user" in file_path else ""
|
|
||||||
)
|
|
||||||
file_type = get_file_type(file_path)
|
|
||||||
|
|
||||||
file_dict = {
|
|
||||||
"path": file_relative_path,
|
|
||||||
"name": os.path.basename(file),
|
|
||||||
# Remove the dot
|
|
||||||
"extension": os.path.splitext(file)[1].lstrip("."),
|
|
||||||
"type": file_type,
|
|
||||||
}
|
|
||||||
modified_files.append(file_dict)
|
|
||||||
|
|
||||||
# Sort the modified files by extension
|
|
||||||
modified_files.sort(key=lambda x: x["extension"])
|
|
||||||
return modified_files
|
|
||||||
|
|
||||||
|
|
||||||
def get_app_root() -> str:
|
|
||||||
"""
|
|
||||||
Get the root directory of the application.
|
|
||||||
|
|
||||||
:return: The root directory of the application.
|
|
||||||
"""
|
|
||||||
app_name = f".{APP_NAME}"
|
|
||||||
default_app_root = os.path.join(os.path.expanduser("~"), app_name)
|
|
||||||
if not os.path.exists(default_app_root):
|
|
||||||
os.makedirs(default_app_root, exist_ok=True)
|
|
||||||
app_root = os.environ.get("AUTOGENSTUDIO_APPDIR") or default_app_root
|
|
||||||
return app_root
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_uri(app_root: str) -> str:
|
|
||||||
"""
|
|
||||||
Get the default database URI for the application.
|
|
||||||
|
|
||||||
:param app_root: The root directory of the application.
|
|
||||||
:return: The default database URI.
|
|
||||||
"""
|
|
||||||
db_uri = f"sqlite:///{os.path.join(app_root, 'database.sqlite')}"
|
|
||||||
db_uri = os.environ.get("AUTOGENSTUDIO_DATABASE_URI") or db_uri
|
|
||||||
logger.info(f"Using database URI: {db_uri}")
|
|
||||||
return db_uri
|
|
||||||
|
|
||||||
|
|
||||||
def init_app_folders(app_file_path: str) -> Dict[str, str]:
|
|
||||||
"""
|
|
||||||
Initialize folders needed for a web server, such as static file directories
|
|
||||||
and user-specific data directories. Also load any .env file if it exists.
|
|
||||||
|
|
||||||
:param root_file_path: The root directory where webserver folders will be created
|
|
||||||
:return: A dictionary with the path of each created folder
|
|
||||||
"""
|
|
||||||
app_root = get_app_root()
|
|
||||||
|
|
||||||
if not os.path.exists(app_root):
|
|
||||||
os.makedirs(app_root, exist_ok=True)
|
|
||||||
|
|
||||||
# load .env file if it exists
|
|
||||||
env_file = os.path.join(app_root, ".env")
|
|
||||||
if os.path.exists(env_file):
|
|
||||||
logger.info(f"Loaded environment variables from {env_file}")
|
|
||||||
load_dotenv(env_file)
|
|
||||||
|
|
||||||
files_static_root = os.path.join(app_root, "files/")
|
|
||||||
static_folder_root = os.path.join(app_file_path, "ui")
|
|
||||||
|
|
||||||
os.makedirs(files_static_root, exist_ok=True)
|
|
||||||
os.makedirs(os.path.join(files_static_root, "user"), exist_ok=True)
|
|
||||||
os.makedirs(static_folder_root, exist_ok=True)
|
|
||||||
folders = {
|
|
||||||
"files_static_root": files_static_root,
|
|
||||||
"static_folder_root": static_folder_root,
|
|
||||||
"app_root": app_root,
|
|
||||||
"database_engine_uri": get_db_uri(app_root=app_root),
|
|
||||||
}
|
|
||||||
logger.info(f"Initialized application data folder: {app_root}")
|
|
||||||
return folders
|
|
||||||
|
|
||||||
|
|
||||||
class Version:
|
|
||||||
def __init__(self, ver_str: str):
|
|
||||||
try:
|
try:
|
||||||
# Split into major.minor.patch
|
if file.get("type", "").startswith("image/"):
|
||||||
self.major, self.minor, self.patch = map(int, ver_str.split("."))
|
# Handle image file using from_base64 method
|
||||||
except (ValueError, AttributeError) as err:
|
# The content is already base64 encoded according to the convertFilesToBase64 function
|
||||||
raise ValueError(f"Invalid version format: {ver_str}. Expected: major.minor.patch") from err
|
image = Image.from_base64(file["content"])
|
||||||
|
messages.append(
|
||||||
|
MultiModalMessage(
|
||||||
|
source="user", content=[image], metadata={"filename": file.get("name", "unknown.img")}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif file.get("type", "").startswith("text/"):
|
||||||
|
# Handle text file as TextMessage
|
||||||
|
text_content = base64.b64decode(file["content"]).decode("utf-8")
|
||||||
|
messages.append(
|
||||||
|
TextMessage(
|
||||||
|
source="user", content=text_content, metadata={"filename": file.get("name", "unknown.txt")}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Log unsupported file types but still try to process based on best guess
|
||||||
|
logger.warning(f"Potentially unsupported file type: {file.get('type')} for file {file.get('name')}")
|
||||||
|
if file.get("type", "").startswith("application/"):
|
||||||
|
# Try to treat as text if it's an application type (like JSON)
|
||||||
|
text_content = base64.b64decode(file["content"]).decode("utf-8")
|
||||||
|
messages.append(
|
||||||
|
TextMessage(
|
||||||
|
source="user",
|
||||||
|
content=text_content,
|
||||||
|
metadata={
|
||||||
|
"filename": file.get("name", "unknown.file"),
|
||||||
|
"filetype": file.get("type", "unknown"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing file {file.get('name')}: {str(e)}")
|
||||||
|
# Continue processing other files even if one fails
|
||||||
|
|
||||||
def __str__(self):
|
return messages
|
||||||
return f"{self.major}.{self.minor}.{self.patch}"
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if isinstance(other, str):
|
|
||||||
other = Version(other)
|
|
||||||
return (self.major, self.minor, self.patch) == (other.major, other.minor, other.patch)
|
|
||||||
|
|
||||||
def __gt__(self, other):
|
|
||||||
if isinstance(other, str):
|
|
||||||
other = Version(other)
|
|
||||||
return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch)
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ class ComponentTestService:
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
logs.append(
|
logs.append(
|
||||||
f"Agent responded with: {response.chat_message.content} to the question : {test_question}"
|
f"Agent responded with: {response.chat_message.to_text()} to the question : {test_question}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logs.append("Agent did not return a valid response")
|
logs.append("Agent did not return a valid response")
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
VERSION = "0.4.1"
|
VERSION = "0.4.2"
|
||||||
__version__ = VERSION
|
__version__ = VERSION
|
||||||
APP_NAME = "autogenstudio"
|
APP_NAME = "autogenstudio"
|
||||||
|
|
|
@ -109,8 +109,8 @@ async def register_auth_dependencies(app: FastAPI, auth_manager: AuthManager) ->
|
||||||
|
|
||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
# print(" *** Route: ", route.path)
|
# print(" *** Route: ", route.path)
|
||||||
if hasattr(route, "app") and isinstance(route.app, FastAPI):
|
if hasattr(route, "app") and isinstance(route.app, FastAPI): # type: ignore
|
||||||
route.app.state.auth_manager = auth_manager
|
route.app.state.auth_manager = auth_manager # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# Manager initialization and cleanup
|
# Manager initialization and cleanup
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
from .connection import WebSocketManager
|
# from .connection import WebSocketManager
|
||||||
|
|
|
@ -2,12 +2,13 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
from autogen_agentchat.base._task import TaskResult
|
from autogen_agentchat.base import TaskResult
|
||||||
from autogen_agentchat.messages import (
|
from autogen_agentchat.messages import (
|
||||||
BaseAgentEvent,
|
BaseAgentEvent,
|
||||||
BaseChatMessage,
|
BaseChatMessage,
|
||||||
|
ChatMessage,
|
||||||
HandoffMessage,
|
HandoffMessage,
|
||||||
ModelClientStreamingChunkEvent,
|
ModelClientStreamingChunkEvent,
|
||||||
MultiModalMessage,
|
MultiModalMessage,
|
||||||
|
@ -32,6 +33,7 @@ from ...datamodel import (
|
||||||
TeamResult,
|
TeamResult,
|
||||||
)
|
)
|
||||||
from ...teammanager import TeamManager
|
from ...teammanager import TeamManager
|
||||||
|
from .run_context import RunContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -79,11 +81,14 @@ class WebSocketManager:
|
||||||
logger.error(f"Connection error for run {run_id}: {e}")
|
logger.error(f"Connection error for run {run_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def start_stream(self, run_id: int, task: str, team_config: dict) -> None:
|
async def start_stream(
|
||||||
|
self, run_id: int, task: str | ChatMessage | Sequence[ChatMessage] | None, team_config: Dict
|
||||||
|
) -> None:
|
||||||
"""Start streaming task execution with proper run management"""
|
"""Start streaming task execution with proper run management"""
|
||||||
if run_id not in self._connections or run_id in self._closed_connections:
|
if run_id not in self._connections or run_id in self._closed_connections:
|
||||||
raise ValueError(f"No active connection for run {run_id}")
|
raise ValueError(f"No active connection for run {run_id}")
|
||||||
|
|
||||||
|
with RunContext.populate_context(run_id=run_id):
|
||||||
team_manager = TeamManager()
|
team_manager = TeamManager()
|
||||||
cancellation_token = CancellationToken()
|
cancellation_token = CancellationToken()
|
||||||
self._cancellation_tokens[run_id] = cancellation_token
|
self._cancellation_tokens[run_id] = cancellation_token
|
||||||
|
@ -92,11 +97,12 @@ class WebSocketManager:
|
||||||
try:
|
try:
|
||||||
# Update run with task and status
|
# Update run with task and status
|
||||||
run = await self._get_run(run_id)
|
run = await self._get_run(run_id)
|
||||||
|
|
||||||
|
if run is not None and run.user_id:
|
||||||
# get user Settings
|
# get user Settings
|
||||||
user_settings = await self._get_settings(run.user_id)
|
user_settings = await self._get_settings(run.user_id)
|
||||||
env_vars = SettingsConfig(**user_settings.config).environment if user_settings else None
|
env_vars = SettingsConfig(**user_settings.config).environment if user_settings else None # type: ignore
|
||||||
if run:
|
run.task = self._convert_images_in_dict(MessageConfig(content=task, source="user").model_dump())
|
||||||
run.task = MessageConfig(content=task, source="user").model_dump()
|
|
||||||
run.status = RunStatus.ACTIVE
|
run.status = RunStatus.ACTIVE
|
||||||
self.db_manager.upsert(run)
|
self.db_manager.upsert(run)
|
||||||
|
|
||||||
|
@ -170,7 +176,7 @@ class WebSocketManager:
|
||||||
db_message = Message(
|
db_message = Message(
|
||||||
session_id=run.session_id,
|
session_id=run.session_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
config=message.model_dump(),
|
config=self._convert_images_in_dict(message.model_dump()),
|
||||||
user_id=None, # You might want to pass this from somewhere
|
user_id=None, # You might want to pass this from somewhere
|
||||||
)
|
)
|
||||||
self.db_manager.upsert(db_message)
|
self.db_manager.upsert(db_message)
|
||||||
|
@ -183,7 +189,7 @@ class WebSocketManager:
|
||||||
if run:
|
if run:
|
||||||
run.status = status
|
run.status = status
|
||||||
if team_result:
|
if team_result:
|
||||||
run.team_result = team_result
|
run.team_result = self._convert_images_in_dict(team_result)
|
||||||
if error:
|
if error:
|
||||||
run.error_message = error
|
run.error_message = error
|
||||||
self.db_manager.upsert(run)
|
self.db_manager.upsert(run)
|
||||||
|
@ -269,6 +275,18 @@ class WebSocketManager:
|
||||||
self._cancellation_tokens.pop(run_id, None)
|
self._cancellation_tokens.pop(run_id, None)
|
||||||
self._input_responses.pop(run_id, None)
|
self._input_responses.pop(run_id, None)
|
||||||
|
|
||||||
|
def _convert_images_in_dict(self, obj: Any) -> Any:
|
||||||
|
"""Recursively find and convert Image objects in dictionaries and lists"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: self._convert_images_in_dict(v) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [self._convert_images_in_dict(item) for item in obj]
|
||||||
|
elif isinstance(obj, AGImage): # Assuming you've imported AGImage
|
||||||
|
# Convert the Image object to a serializable format
|
||||||
|
return {"type": "image", "url": f"data:image/png;base64,{obj.to_base64()}", "alt": "Image"}
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
async def _send_message(self, run_id: int, message: dict) -> None:
|
async def _send_message(self, run_id: int, message: dict) -> None:
|
||||||
"""Send a message through the WebSocket with connection state checking
|
"""Send a message through the WebSocket with connection state checking
|
||||||
|
|
||||||
|
@ -283,7 +301,7 @@ class WebSocketManager:
|
||||||
try:
|
try:
|
||||||
if run_id in self._connections:
|
if run_id in self._connections:
|
||||||
websocket = self._connections[run_id]
|
websocket = self._connections[run_id]
|
||||||
await websocket.send_json(message)
|
await websocket.send_json(self._convert_images_in_dict(message))
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
|
logger.warning(f"WebSocket disconnected while sending message for run {run_id}")
|
||||||
await self.disconnect(run_id)
|
await self.disconnect(run_id)
|
||||||
|
@ -330,13 +348,20 @@ class WebSocketManager:
|
||||||
try:
|
try:
|
||||||
if isinstance(message, MultiModalMessage):
|
if isinstance(message, MultiModalMessage):
|
||||||
message_dump = message.model_dump()
|
message_dump = message.model_dump()
|
||||||
message_dump["content"] = [
|
|
||||||
message_dump["content"][0],
|
message_content = []
|
||||||
|
for row in message_dump["content"]:
|
||||||
|
if isinstance(row, dict) and "data" in row:
|
||||||
|
message_content.append(
|
||||||
{
|
{
|
||||||
"url": f"data:image/png;base64,{message_dump['content'][1]['data']}",
|
"url": f"data:image/png;base64,{row['data']}",
|
||||||
"alt": "WebSurfer Screenshot",
|
"alt": "WebSurfer Screenshot",
|
||||||
},
|
}
|
||||||
]
|
)
|
||||||
|
else:
|
||||||
|
message_content.append(row)
|
||||||
|
message_dump["content"] = message_content
|
||||||
|
|
||||||
return {"type": "message", "data": message_dump}
|
return {"type": "message", "data": message_dump}
|
||||||
|
|
||||||
elif isinstance(message, TeamResult):
|
elif isinstance(message, TeamResult):
|
||||||
|
@ -365,6 +390,7 @@ class WebSocketManager:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Message formatting error: {e}")
|
logger.error(f"Message formatting error: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _get_run(self, run_id: int) -> Optional[Run]:
|
async def _get_run(self, run_id: int) -> Optional[Run]:
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, ClassVar, Generator
|
||||||
|
|
||||||
|
|
||||||
|
class RunContext:
|
||||||
|
RUN_CONTEXT_VAR: ClassVar[ContextVar] = ContextVar("RUN_CONTEXT_VAR")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def populate_context(cls, run_id) -> Generator[None, Any, None]:
|
||||||
|
token = RunContext.RUN_CONTEXT_VAR.set(run_id)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
RunContext.RUN_CONTEXT_VAR.reset(token)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def current_run_id(cls) -> str:
|
||||||
|
try:
|
||||||
|
return cls.RUN_CONTEXT_VAR.get()
|
||||||
|
except LookupError as e:
|
||||||
|
raise RuntimeError("Error getting run id") from e
|
|
@ -1,10 +1,11 @@
|
||||||
# api/routes/sessions.py
|
# api/routes/sessions.py
|
||||||
|
import re
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ...datamodel import Message, Run, Session
|
from ...datamodel import Message, Response, Run, Session
|
||||||
from ..deps import get_db
|
from ..deps import get_db
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -27,12 +28,16 @@ async def get_session(session_id: int, user_id: str, db=Depends(get_db)) -> Dict
|
||||||
|
|
||||||
|
|
||||||
@router.post("/")
|
@router.post("/")
|
||||||
async def create_session(session: Session, db=Depends(get_db)) -> Dict:
|
async def create_session(session: Session, db=Depends(get_db)) -> Response:
|
||||||
"""Create a new session"""
|
"""Create a new session"""
|
||||||
|
try:
|
||||||
response = db.upsert(session)
|
response = db.upsert(session)
|
||||||
if not response.status:
|
if not response.status:
|
||||||
raise HTTPException(status_code=400, detail=response.message)
|
return Response(status=False, message=f"Failed to create session: {response.message}")
|
||||||
return {"status": True, "data": response.data}
|
return Response(status=True, data=response.data, message="Session created successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating session: {str(e)}")
|
||||||
|
return Response(status=False, message=f"Failed to create session: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{session_id}")
|
@router.put("/{session_id}")
|
||||||
|
|
|
@ -14,6 +14,7 @@ router = APIRouter()
|
||||||
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
|
async def list_teams(user_id: str, db=Depends(get_db)) -> Dict:
|
||||||
"""List all teams for a user"""
|
"""List all teams for a user"""
|
||||||
response = db.get(Team, filters={"user_id": user_id})
|
response = db.get(Team, filters={"user_id": user_id})
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
default_gallery = create_default_gallery()
|
default_gallery = create_default_gallery()
|
||||||
default_team = Team(user_id=user_id, component=default_gallery.components.teams[0].model_dump())
|
default_team = Team(user_id=user_id, component=default_gallery.components.teams[0].model_dump())
|
||||||
|
|
|
@ -8,10 +8,11 @@ from fastapi.websockets import WebSocketState
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ...datamodel import Run, RunStatus
|
from ...datamodel import Run, RunStatus
|
||||||
|
from ...utils.utils import construct_task
|
||||||
from ..auth.dependencies import get_ws_auth_manager
|
from ..auth.dependencies import get_ws_auth_manager
|
||||||
from ..auth.wsauth import WebSocketAuthHandler
|
from ..auth.wsauth import WebSocketAuthHandler
|
||||||
from ..deps import get_db, get_websocket_manager
|
from ..deps import get_db, get_websocket_manager
|
||||||
from ..managers import WebSocketManager
|
from ..managers.connection import WebSocketManager
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -26,21 +27,6 @@ async def run_websocket(
|
||||||
):
|
):
|
||||||
"""WebSocket endpoint for run communication"""
|
"""WebSocket endpoint for run communication"""
|
||||||
|
|
||||||
async def start_stream_wrapper(run_id, task, team_config):
|
|
||||||
try:
|
|
||||||
await ws_manager.start_stream(run_id, task, team_config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in start_stream for run {run_id}: {str(e)}")
|
|
||||||
# Optionally notify the client about the error
|
|
||||||
if websocket.client_state == WebSocketState.CONNECTED:
|
|
||||||
await websocket.send_json(
|
|
||||||
{
|
|
||||||
"type": "error",
|
|
||||||
"error": f"Stream processing error: {str(e)}",
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify run exists before connecting
|
# Verify run exists before connecting
|
||||||
run_response = db.get(Run, filters={"id": run_id}, return_json=False)
|
run_response = db.get(Run, filters={"id": run_id}, return_json=False)
|
||||||
|
@ -98,11 +84,12 @@ async def run_websocket(
|
||||||
if message.get("type") == "start":
|
if message.get("type") == "start":
|
||||||
# Handle start message
|
# Handle start message
|
||||||
logger.info(f"Received start request for run {run_id}")
|
logger.info(f"Received start request for run {run_id}")
|
||||||
task = message.get("task")
|
task = construct_task(query=message.get("task"), files=message.get("files"))
|
||||||
|
|
||||||
team_config = message.get("team_config")
|
team_config = message.get("team_config")
|
||||||
if task and team_config:
|
if task and team_config:
|
||||||
# Start the stream in a separate task
|
# Start the stream in a separate task
|
||||||
asyncio.create_task(start_stream_wrapper(run_id, task, team_config))
|
asyncio.create_task(ws_manager.start_stream(run_id, task, team_config))
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Invalid start message format for run {run_id}")
|
logger.warning(f"Invalid start message format for run {run_id}")
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
|
|
|
@ -11,8 +11,12 @@ import {
|
||||||
PanelLeftOpen,
|
PanelLeftOpen,
|
||||||
GalleryHorizontalEnd,
|
GalleryHorizontalEnd,
|
||||||
Rocket,
|
Rocket,
|
||||||
|
Beaker,
|
||||||
|
LucideBeaker,
|
||||||
|
FlaskConical,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import Icon from "./icons";
|
import Icon from "./icons";
|
||||||
|
import { BeakerIcon } from "@heroicons/react/24/outline";
|
||||||
|
|
||||||
interface INavItem {
|
interface INavItem {
|
||||||
name: string;
|
name: string;
|
||||||
|
@ -44,6 +48,12 @@ const navigation: INavItem[] = [
|
||||||
icon: GalleryHorizontalEnd,
|
icon: GalleryHorizontalEnd,
|
||||||
breadcrumbs: [{ name: "Gallery", href: "/gallery", current: true }],
|
breadcrumbs: [{ name: "Gallery", href: "/gallery", current: true }],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Labs",
|
||||||
|
href: "/labs",
|
||||||
|
icon: FlaskConical,
|
||||||
|
breadcrumbs: [{ name: "Labs", href: "/labs", current: true }],
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Deploy",
|
name: "Deploy",
|
||||||
href: "/deploy",
|
href: "/deploy",
|
||||||
|
|
|
@ -42,6 +42,7 @@ export interface FunctionExecutionResult {
|
||||||
export interface BaseMessageConfig {
|
export interface BaseMessageConfig {
|
||||||
source: string;
|
source: string;
|
||||||
models_usage?: RequestUsage;
|
models_usage?: RequestUsage;
|
||||||
|
metadata?: Record<string, string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TextMessageConfig extends BaseMessageConfig {
|
export interface TextMessageConfig extends BaseMessageConfig {
|
||||||
|
@ -373,7 +374,7 @@ export interface Run {
|
||||||
created_at: string;
|
created_at: string;
|
||||||
updated_at?: string;
|
updated_at?: string;
|
||||||
status: RunStatus;
|
status: RunStatus;
|
||||||
task: AgentMessageConfig;
|
task: AgentMessageConfig[];
|
||||||
team_result: TeamResult | null;
|
team_result: TeamResult | null;
|
||||||
messages: Message[];
|
messages: Message[];
|
||||||
error_message?: string;
|
error_message?: string;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import { RcFile } from "antd/es/upload";
|
||||||
import { IStatus } from "../types/app";
|
import { IStatus } from "../types/app";
|
||||||
|
|
||||||
export const getServerUrl = () => {
|
export const getServerUrl = () => {
|
||||||
|
@ -116,3 +117,24 @@ export const fetchVersion = () => {
|
||||||
return null;
|
return null;
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const convertFilesToBase64 = async (files: RcFile[] = []) => {
|
||||||
|
return Promise.all(
|
||||||
|
files.map(async (file) => {
|
||||||
|
return new Promise<{ name: string; content: string; type: string }>(
|
||||||
|
(resolve, reject) => {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = () => {
|
||||||
|
// Extract base64 content from reader result
|
||||||
|
const base64Content = reader.result as string;
|
||||||
|
// Remove the data URL prefix (e.g., "data:image/png;base64,")
|
||||||
|
const base64Data = base64Content.split(",")[1] || base64Content;
|
||||||
|
resolve({ name: file.name, content: base64Data, type: file.type });
|
||||||
|
};
|
||||||
|
reader.onerror = reject;
|
||||||
|
reader.readAsDataURL(file);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
|
@ -158,11 +158,11 @@ export const TruncatableText = memo(
|
||||||
|
|
||||||
{isFullscreen && (
|
{isFullscreen && (
|
||||||
<div
|
<div
|
||||||
className="fixed inset-0 bg-black/80 z-50 flex items-center justify-center"
|
className="fixed inset-0 dark:bg-black/80 bg-black/10 z-50 flex items-center justify-center"
|
||||||
onClick={() => setIsFullscreen(false)}
|
onClick={() => setIsFullscreen(false)}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
className="relative bg-secondary scroll w-full h-full md:w-4/5 md:h-4/5 md:rounded-lg p-8 overflow-auto"
|
className="relative bg-primary scroll w-full h-full md:w-4/5 md:h-4/5 md:rounded-lg p-8 overflow-auto"
|
||||||
style={{ opacity: 0.95 }}
|
style={{ opacity: 0.95 }}
|
||||||
onClick={(e) => e.stopPropagation()}
|
onClick={(e) => e.stopPropagation()}
|
||||||
>
|
>
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import React, { useState } from "react";
|
import React, { useState, useEffect } from "react";
|
||||||
import { Tabs, Button, Tooltip, Drawer, Input } from "antd";
|
import { Tabs, Button, Tooltip, Drawer, Input } from "antd";
|
||||||
import {
|
import {
|
||||||
Package,
|
Package,
|
||||||
|
@ -12,6 +12,7 @@ import {
|
||||||
Copy,
|
Copy,
|
||||||
Trash,
|
Trash,
|
||||||
Plus,
|
Plus,
|
||||||
|
Download,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { ComponentEditor } from "../teambuilder/builder/component-editor/component-editor";
|
import { ComponentEditor } from "../teambuilder/builder/component-editor/component-editor";
|
||||||
import { TruncatableText } from "../atoms";
|
import { TruncatableText } from "../atoms";
|
||||||
|
@ -160,6 +161,13 @@ export const GalleryDetail: React.FC<{
|
||||||
gallery.config.metadata.description
|
gallery.config.metadata.description
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setTempName(gallery.config.name);
|
||||||
|
setTempDescription(gallery.config.metadata.description);
|
||||||
|
setActiveTab("team");
|
||||||
|
setEditingComponent(null);
|
||||||
|
}, [gallery.id]);
|
||||||
|
|
||||||
const updateGallery = (
|
const updateGallery = (
|
||||||
category: CategoryKey,
|
category: CategoryKey,
|
||||||
updater: (
|
updater: (
|
||||||
|
@ -286,6 +294,21 @@ export const GalleryDetail: React.FC<{
|
||||||
setIsEditingDetails(false);
|
setIsEditingDetails(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleDownload = () => {
|
||||||
|
const dataStr = JSON.stringify(gallery, null, 2);
|
||||||
|
const dataBlob = new Blob([dataStr], { type: "application/json" });
|
||||||
|
const url = URL.createObjectURL(dataBlob);
|
||||||
|
const link = document.createElement("a");
|
||||||
|
link.href = url;
|
||||||
|
link.download = `${gallery.config.name
|
||||||
|
.toLowerCase()
|
||||||
|
.replace(/\s+/g, "_")}.json`;
|
||||||
|
document.body.appendChild(link);
|
||||||
|
link.click();
|
||||||
|
document.body.removeChild(link);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
};
|
||||||
|
|
||||||
const tabItems = Object.entries(iconMap).map(([key, Icon]) => ({
|
const tabItems = Object.entries(iconMap).map(([key, Icon]) => ({
|
||||||
key,
|
key,
|
||||||
label: (
|
label: (
|
||||||
|
@ -355,25 +378,6 @@ export const GalleryDetail: React.FC<{
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
{!isEditingDetails ? (
|
|
||||||
<Button
|
|
||||||
icon={<Edit className="w-4 h-4" />}
|
|
||||||
onClick={() => setIsEditingDetails(true)}
|
|
||||||
type="text"
|
|
||||||
className="text-white hover:text-white/80"
|
|
||||||
>
|
|
||||||
Edit
|
|
||||||
</Button>
|
|
||||||
) : (
|
|
||||||
<div className="flex gap-2">
|
|
||||||
<Button onClick={() => setIsEditingDetails(false)}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button type="primary" onClick={handleDetailsSave}>
|
|
||||||
Save
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
{isEditingDetails ? (
|
{isEditingDetails ? (
|
||||||
<TextArea
|
<TextArea
|
||||||
|
@ -383,9 +387,39 @@ export const GalleryDetail: React.FC<{
|
||||||
rows={2}
|
rows={2}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
|
<div className="flex flex-col gap-2">
|
||||||
<p className="text-secondary w-1/2 mt-2 line-clamp-2">
|
<p className="text-secondary w-1/2 mt-2 line-clamp-2">
|
||||||
{gallery.config.metadata.description}
|
{gallery.config.metadata.description}
|
||||||
</p>
|
</p>
|
||||||
|
<div className="flex gap-0">
|
||||||
|
<Tooltip title="Edit Gallery">
|
||||||
|
<Button
|
||||||
|
icon={<Edit className="w-4 h-4" />}
|
||||||
|
onClick={() => setIsEditingDetails(true)}
|
||||||
|
type="text"
|
||||||
|
className="text-white hover:text-white/80"
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip title="Download Gallery">
|
||||||
|
<Button
|
||||||
|
icon={<Download className="w-4 h-4" />}
|
||||||
|
onClick={handleDownload}
|
||||||
|
type="text"
|
||||||
|
className="text-white hover:text-white/80"
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{isEditingDetails && (
|
||||||
|
<div className="flex gap-2 mt-2">
|
||||||
|
<Button onClick={() => setIsEditingDetails(false)}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button type="primary" onClick={handleDetailsSave}>
|
||||||
|
Save
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="flex gap-2">
|
<div className="flex gap-2">
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
import React from "react";
|
||||||
|
import { Alert } from "antd";
|
||||||
|
import { copyToClipboard } from "./guides";
|
||||||
|
import { Download } from "lucide-react";
|
||||||
|
|
||||||
|
const ComponentLab: React.FC = () => {
|
||||||
|
return (
|
||||||
|
<div className="">
|
||||||
|
<h1 className="tdext-2xl font-bold mb-6">
|
||||||
|
Using AutoGen Studio Teams in Python Code and REST API
|
||||||
|
</h1>
|
||||||
|
|
||||||
|
<Alert
|
||||||
|
className="mb-6"
|
||||||
|
message="Prerequisites"
|
||||||
|
description={
|
||||||
|
<ul className="list-disc pl-4 mt-2 space-y-1">
|
||||||
|
<li>AutoGen Studio installed</li>
|
||||||
|
</ul>
|
||||||
|
}
|
||||||
|
type="info"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ComponentLab;
|
|
@ -0,0 +1,27 @@
|
||||||
|
import React from "react";
|
||||||
|
import { Lab } from "../types";
|
||||||
|
import ComponentLab from "./component";
|
||||||
|
|
||||||
|
interface LabContentProps {
|
||||||
|
lab: Lab;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const copyToClipboard = (text: string) => {
|
||||||
|
navigator.clipboard.writeText(text);
|
||||||
|
};
|
||||||
|
export const LabContent: React.FC<LabContentProps> = ({ lab }) => {
|
||||||
|
// Render different content based on guide type and id
|
||||||
|
switch (lab.id) {
|
||||||
|
case "python-setup":
|
||||||
|
return <ComponentLab />;
|
||||||
|
|
||||||
|
default:
|
||||||
|
return (
|
||||||
|
<div className="text-secondary">
|
||||||
|
A Lab with the title <strong>{lab.title}</strong> is work in progress!
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export default LabContent;
|
|
@ -0,0 +1,87 @@
|
||||||
|
import React, { useState, useEffect } from "react";
|
||||||
|
import { ChevronRight, TriangleAlert } from "lucide-react";
|
||||||
|
import { LabsSidebar } from "./sidebar";
|
||||||
|
import { Lab, defaultLabs } from "./types";
|
||||||
|
import { LabContent } from "./labs/guides";
|
||||||
|
|
||||||
|
export const LabsManager: React.FC = () => {
|
||||||
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
const [labs, setLabs] = useState<Lab[]>([]);
|
||||||
|
const [currentLab, setcurrentLab] = useState<Lab | null>(null);
|
||||||
|
const [isSidebarOpen, setIsSidebarOpen] = useState(() => {
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
const stored = localStorage.getItem("labsSidebar");
|
||||||
|
return stored !== null ? JSON.parse(stored) : true;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Persist sidebar state
|
||||||
|
useEffect(() => {
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
localStorage.setItem("labsSidebar", JSON.stringify(isSidebarOpen));
|
||||||
|
}
|
||||||
|
}, [isSidebarOpen]);
|
||||||
|
|
||||||
|
// Set first guide as current if none selected
|
||||||
|
useEffect(() => {
|
||||||
|
if (!currentLab && labs.length > 0) {
|
||||||
|
setcurrentLab(labs[0]);
|
||||||
|
}
|
||||||
|
}, [labs, currentLab]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="relative flex h-full w-full">
|
||||||
|
{/* Sidebar */}
|
||||||
|
<div
|
||||||
|
className={`absolute left-0 top-0 h-full transition-all duration-200 ease-in-out ${
|
||||||
|
isSidebarOpen ? "w-64" : "w-12"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<LabsSidebar
|
||||||
|
isOpen={isSidebarOpen}
|
||||||
|
labs={labs}
|
||||||
|
currentLab={currentLab}
|
||||||
|
onToggle={() => setIsSidebarOpen(!isSidebarOpen)}
|
||||||
|
onSelectLab={setcurrentLab}
|
||||||
|
isLoading={isLoading}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Main Content */}
|
||||||
|
<div
|
||||||
|
className={`flex-1 transition-all max-w-5xl -mr-6 duration-200 ${
|
||||||
|
isSidebarOpen ? "ml-64" : "ml-12"
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<div className="p-4 pt-2">
|
||||||
|
{/* Breadcrumb */}
|
||||||
|
<div className="flex items-center gap-2 mb-4 text-sm">
|
||||||
|
<span className="text-primary font-medium">Labs</span>
|
||||||
|
{currentLab && (
|
||||||
|
<>
|
||||||
|
<ChevronRight className="w-4 h-4 text-secondary" />
|
||||||
|
<span className="text-secondary">{currentLab.title}</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="rounded border border-secondary border-dashed p-2 text-sm mb-4">
|
||||||
|
<TriangleAlert className="w-4 h-4 inline-block mr-2 -mt-1 text-secondary " />{" "}
|
||||||
|
Labs is designed to host experimental features for building and
|
||||||
|
debugging multiagent applications.
|
||||||
|
</div>
|
||||||
|
{/* Content Area */}
|
||||||
|
{currentLab ? (
|
||||||
|
<LabContent lab={currentLab} />
|
||||||
|
) : (
|
||||||
|
<div className="flex items-center justify-center h-[calc(100vh-190px)] text-secondary">
|
||||||
|
Select a lab from the sidebar to get started
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default LabsManager;
|
|
@ -0,0 +1,111 @@
|
||||||
|
import React from "react";
|
||||||
|
import { Button, Tooltip } from "antd";
|
||||||
|
import {
|
||||||
|
PanelLeftClose,
|
||||||
|
PanelLeftOpen,
|
||||||
|
Book,
|
||||||
|
InfoIcon,
|
||||||
|
RefreshCcw,
|
||||||
|
} from "lucide-react";
|
||||||
|
import type { Lab } from "./types";
|
||||||
|
|
||||||
|
interface LabsSidebarProps {
|
||||||
|
isOpen: boolean;
|
||||||
|
labs: Lab[];
|
||||||
|
currentLab: Lab | null;
|
||||||
|
onToggle: () => void;
|
||||||
|
onSelectLab: (guide: Lab) => void;
|
||||||
|
isLoading?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const LabsSidebar: React.FC<LabsSidebarProps> = ({
|
||||||
|
isOpen,
|
||||||
|
labs,
|
||||||
|
currentLab,
|
||||||
|
onToggle,
|
||||||
|
onSelectLab,
|
||||||
|
isLoading = false,
|
||||||
|
}) => {
|
||||||
|
// Render collapsed state
|
||||||
|
if (!isOpen) {
|
||||||
|
return (
|
||||||
|
<div className="h-full border-r border-secondary">
|
||||||
|
<div className="p-2 -ml-2">
|
||||||
|
<Tooltip title="Documentation">
|
||||||
|
<button
|
||||||
|
onClick={onToggle}
|
||||||
|
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
|
||||||
|
>
|
||||||
|
<PanelLeftOpen strokeWidth={1.5} className="h-6 w-6" />
|
||||||
|
</button>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="h-full border-r border-secondary">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-center justify-between pt-0 p-4 pl-2 pr-2 border-b border-secondary">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
{/* <Book className="w-4 h-4" /> */}
|
||||||
|
<span className="text-primary font-medium">Labs</span>
|
||||||
|
{/* <span className="px-2 py-0.5 text-xs bg-accent/10 text-accent rounded">
|
||||||
|
{guides.length}
|
||||||
|
</span> */}
|
||||||
|
</div>
|
||||||
|
<Tooltip title="Close Sidebar">
|
||||||
|
<button
|
||||||
|
onClick={onToggle}
|
||||||
|
className="p-2 rounded-md hover:bg-secondary hover:text-accent text-secondary transition-colors focus:outline-none focus:ring-2 focus:ring-accent focus:ring-opacity-50"
|
||||||
|
>
|
||||||
|
<PanelLeftClose strokeWidth={1.5} className="h-6 w-6" />
|
||||||
|
</button>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Loading State */}
|
||||||
|
{isLoading && (
|
||||||
|
<div className="p-4">
|
||||||
|
<RefreshCcw className="w-4 h-4 inline-block animate-spin" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Empty State */}
|
||||||
|
{!isLoading && labs.length === 0 && (
|
||||||
|
<div className="p-2 mt-2 mr-2 text-center text-secondary text-sm border border-dashed rounded">
|
||||||
|
<InfoIcon className="w-4 h-4 inline-block mr-1.5 -mt-0.5" />
|
||||||
|
No labs available. Please check back later.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Guides List */}
|
||||||
|
<div className="overflow-y-auto h-[calc(100%-64px)] mt-4">
|
||||||
|
{labs.map((lab) => (
|
||||||
|
<div key={lab.id} className="relative">
|
||||||
|
<div
|
||||||
|
className={`absolute top-1 left-0.5 z-50 h-[calc(100%-8px)]
|
||||||
|
w-1 bg-opacity-80 rounded ${
|
||||||
|
currentLab?.id === lab.id ? "bg-accent" : "bg-tertiary"
|
||||||
|
}`}
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
className={`group ml-1 flex flex-col p-2 rounded-l cursor-pointer hover:bg-secondary ${
|
||||||
|
currentLab?.id === lab.id
|
||||||
|
? "border-accent bg-secondary"
|
||||||
|
: "border-transparent"
|
||||||
|
}`}
|
||||||
|
onClick={() => onSelectLab(lab)}
|
||||||
|
>
|
||||||
|
{/* Guide Title */}
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<span className="text-sm truncate">{lab.title}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
|
@ -0,0 +1,13 @@
|
||||||
|
export interface Lab {
|
||||||
|
id: string;
|
||||||
|
title: string;
|
||||||
|
type: "python" | "docker" | "cloud";
|
||||||
|
}
|
||||||
|
|
||||||
|
export const defaultLabs: Lab[] = [
|
||||||
|
{
|
||||||
|
id: "component-builder",
|
||||||
|
title: "Component Builder",
|
||||||
|
type: "python",
|
||||||
|
},
|
||||||
|
];
|
|
@ -1,6 +1,6 @@
|
||||||
import * as React from "react";
|
import * as React from "react";
|
||||||
import { Button, message, Tooltip } from "antd";
|
import { Button, message, Tooltip } from "antd";
|
||||||
import { getServerUrl } from "../../../utils/utils";
|
import { convertFilesToBase64, getServerUrl } from "../../../utils/utils";
|
||||||
import { IStatus } from "../../../types/app";
|
import { IStatus } from "../../../types/app";
|
||||||
import {
|
import {
|
||||||
Run,
|
Run,
|
||||||
|
@ -27,6 +27,7 @@ import {
|
||||||
X,
|
X,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import SessionDropdown from "./sessiondropdown";
|
import SessionDropdown from "./sessiondropdown";
|
||||||
|
import { RcFile } from "antd/es/upload";
|
||||||
const logo = require("../../../../images/landing/welcome.svg").default;
|
const logo = require("../../../../images/landing/welcome.svg").default;
|
||||||
|
|
||||||
interface ChatViewProps {
|
interface ChatViewProps {
|
||||||
|
@ -395,7 +396,7 @@ export default function ChatView({
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const runTask = async (query: string) => {
|
const runTask = async (query: string, files: RcFile[] = []) => {
|
||||||
setError(null);
|
setError(null);
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
|
|
||||||
|
@ -405,13 +406,13 @@ export default function ChatView({
|
||||||
setActiveSocket(null);
|
setActiveSocket(null);
|
||||||
activeSocketRef.current = null;
|
activeSocketRef.current = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (inputTimeoutRef.current) {
|
if (inputTimeoutRef.current) {
|
||||||
clearTimeout(inputTimeoutRef.current);
|
clearTimeout(inputTimeoutRef.current);
|
||||||
inputTimeoutRef.current = null;
|
inputTimeoutRef.current = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!session?.id || !teamConfig) {
|
if (!session?.id || !teamConfig) {
|
||||||
// Add teamConfig check
|
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -419,6 +420,9 @@ export default function ChatView({
|
||||||
try {
|
try {
|
||||||
const runId = await createRun(session.id);
|
const runId = await createRun(session.id);
|
||||||
|
|
||||||
|
// Process files using the extracted function
|
||||||
|
const processedFiles = await convertFilesToBase64(files);
|
||||||
|
|
||||||
// Initialize run state BEFORE websocket connection
|
// Initialize run state BEFORE websocket connection
|
||||||
setCurrentRun({
|
setCurrentRun({
|
||||||
id: runId,
|
id: runId,
|
||||||
|
@ -433,8 +437,8 @@ export default function ChatView({
|
||||||
error_message: undefined,
|
error_message: undefined,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Setup WebSocket
|
// Setup WebSocket with files
|
||||||
const socket = setupWebSocket(runId, query);
|
const socket = setupWebSocket(runId, query, processedFiles);
|
||||||
setActiveSocket(socket);
|
setActiveSocket(socket);
|
||||||
activeSocketRef.current = socket;
|
activeSocketRef.current = socket;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
@ -444,7 +448,11 @@ export default function ChatView({
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const setupWebSocket = (runId: number, query: string): WebSocket => {
|
const setupWebSocket = (
|
||||||
|
runId: number,
|
||||||
|
query: string,
|
||||||
|
files: { name: string; type: string; content: string }[]
|
||||||
|
): WebSocket => {
|
||||||
if (!session || !session.id) {
|
if (!session || !session.id) {
|
||||||
throw new Error("Invalid session configuration");
|
throw new Error("Invalid session configuration");
|
||||||
}
|
}
|
||||||
|
@ -465,6 +473,7 @@ export default function ChatView({
|
||||||
id: runId,
|
id: runId,
|
||||||
created_at: new Date().toISOString(),
|
created_at: new Date().toISOString(),
|
||||||
status: "active",
|
status: "active",
|
||||||
|
|
||||||
task: createMessage(
|
task: createMessage(
|
||||||
{ content: query, source: "user" },
|
{ content: query, source: "user" },
|
||||||
runId,
|
runId,
|
||||||
|
@ -481,6 +490,7 @@ export default function ChatView({
|
||||||
JSON.stringify({
|
JSON.stringify({
|
||||||
type: "start",
|
type: "start",
|
||||||
task: query,
|
task: query,
|
||||||
|
files: files,
|
||||||
team_config: teamConfig,
|
team_config: teamConfig,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
@ -657,7 +667,10 @@ export default function ChatView({
|
||||||
onSubmit={runTask}
|
onSubmit={runTask}
|
||||||
loading={loading}
|
loading={loading}
|
||||||
error={error}
|
error={error}
|
||||||
disabled={currentRun?.status === "awaiting_input"}
|
disabled={
|
||||||
|
currentRun?.status === "awaiting_input" ||
|
||||||
|
currentRun?.status === "active"
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
"use client";
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
PaperAirplaneIcon,
|
PaperAirplaneIcon,
|
||||||
Cog6ToothIcon,
|
Cog6ToothIcon,
|
||||||
|
@ -7,9 +5,33 @@ import {
|
||||||
} from "@heroicons/react/24/outline";
|
} from "@heroicons/react/24/outline";
|
||||||
import * as React from "react";
|
import * as React from "react";
|
||||||
import { IStatus } from "../../../types/app";
|
import { IStatus } from "../../../types/app";
|
||||||
|
import { Upload, message, Button, Tooltip, notification } from "antd";
|
||||||
|
import type { UploadFile, UploadProps, RcFile } from "antd/es/upload/interface";
|
||||||
|
import {
|
||||||
|
FileTextIcon,
|
||||||
|
ImageIcon,
|
||||||
|
Paperclip,
|
||||||
|
UploadIcon,
|
||||||
|
XIcon,
|
||||||
|
} from "lucide-react";
|
||||||
|
import { truncateText } from "../../../utils/utils";
|
||||||
|
|
||||||
|
// Maximum file size in bytes (5MB)
|
||||||
|
const MAX_FILE_SIZE = 5 * 1024 * 1024;
|
||||||
|
// Allowed file types
|
||||||
|
const ALLOWED_FILE_TYPES = [
|
||||||
|
"text/plain",
|
||||||
|
"image/jpeg",
|
||||||
|
"image/png",
|
||||||
|
"image/gif",
|
||||||
|
"image/svg+xml",
|
||||||
|
];
|
||||||
|
|
||||||
|
// Threshold for large text files (in characters)
|
||||||
|
const LARGE_TEXT_THRESHOLD = 1500;
|
||||||
|
|
||||||
interface ChatInputProps {
|
interface ChatInputProps {
|
||||||
onSubmit: (text: string) => void;
|
onSubmit: (text: string, files: RcFile[]) => void;
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
error: IStatus | null;
|
error: IStatus | null;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
|
@ -23,7 +45,11 @@ export default function ChatInput({
|
||||||
}: ChatInputProps) {
|
}: ChatInputProps) {
|
||||||
const textAreaRef = React.useRef<HTMLTextAreaElement>(null);
|
const textAreaRef = React.useRef<HTMLTextAreaElement>(null);
|
||||||
const [previousLoading, setPreviousLoading] = React.useState(loading);
|
const [previousLoading, setPreviousLoading] = React.useState(loading);
|
||||||
const [text, setText] = React.useState("");
|
const [text, setText] = React.useState("What is the capital of France?");
|
||||||
|
const [fileList, setFileList] = React.useState<UploadFile[]>([]);
|
||||||
|
const [dragOver, setDragOver] = React.useState(false);
|
||||||
|
const [notificationApi, notificationContextHolder] =
|
||||||
|
notification.useNotification();
|
||||||
|
|
||||||
const textAreaDefaultHeight = "64px";
|
const textAreaDefaultHeight = "64px";
|
||||||
const isInputDisabled = disabled || loading;
|
const isInputDisabled = disabled || loading;
|
||||||
|
@ -31,7 +57,9 @@ export default function ChatInput({
|
||||||
// Handle textarea auto-resize
|
// Handle textarea auto-resize
|
||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
if (textAreaRef.current) {
|
if (textAreaRef.current) {
|
||||||
textAreaRef.current.style.height = textAreaDefaultHeight;
|
// Temporarily set height to auto to get proper scrollHeight
|
||||||
|
textAreaRef.current.style.height = "auto";
|
||||||
|
// Then set to the scroll height
|
||||||
const scrollHeight = textAreaRef.current.scrollHeight;
|
const scrollHeight = textAreaRef.current.scrollHeight;
|
||||||
textAreaRef.current.style.height = `${scrollHeight}px`;
|
textAreaRef.current.style.height = `${scrollHeight}px`;
|
||||||
}
|
}
|
||||||
|
@ -45,11 +73,139 @@ export default function ChatInput({
|
||||||
setPreviousLoading(loading);
|
setPreviousLoading(loading);
|
||||||
}, [loading, error, previousLoading]);
|
}, [loading, error, previousLoading]);
|
||||||
|
|
||||||
|
// Add paste event listener
|
||||||
|
React.useEffect(() => {
|
||||||
|
const handlePaste = (e: ClipboardEvent) => {
|
||||||
|
if (isInputDisabled) return;
|
||||||
|
|
||||||
|
// Handle image paste
|
||||||
|
if (e.clipboardData?.items) {
|
||||||
|
let hasImageItem = false;
|
||||||
|
|
||||||
|
for (let i = 0; i < e.clipboardData.items.length; i++) {
|
||||||
|
const item = e.clipboardData.items[i];
|
||||||
|
|
||||||
|
// Handle image items
|
||||||
|
if (item.type.indexOf("image/") === 0) {
|
||||||
|
hasImageItem = true;
|
||||||
|
const file = item.getAsFile();
|
||||||
|
|
||||||
|
if (file && file.size <= MAX_FILE_SIZE) {
|
||||||
|
// Prevent the default paste behavior for images
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
// Create a unique file name
|
||||||
|
const fileName = `pasted-image-${new Date().getTime()}.png`;
|
||||||
|
|
||||||
|
// Create a new File with a proper name
|
||||||
|
const namedFile = new File([file], fileName, { type: file.type });
|
||||||
|
|
||||||
|
// Convert to the expected UploadFile format
|
||||||
|
const uploadFile: UploadFile = {
|
||||||
|
uid: `paste-${Date.now()}`,
|
||||||
|
name: fileName,
|
||||||
|
status: "done",
|
||||||
|
size: namedFile.size,
|
||||||
|
type: namedFile.type,
|
||||||
|
originFileObj: namedFile as RcFile,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add to file list
|
||||||
|
setFileList((prev) => [...prev, uploadFile]);
|
||||||
|
|
||||||
|
// Show successful paste notification
|
||||||
|
message.success(`Image pasted successfully`);
|
||||||
|
} else if (file && file.size > MAX_FILE_SIZE) {
|
||||||
|
message.error(`Pasted image is too large. Maximum size is 5MB.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle text items - only if there's a large amount of text
|
||||||
|
if (item.type === "text/plain" && !hasImageItem) {
|
||||||
|
item.getAsString((text) => {
|
||||||
|
// Only process for large text
|
||||||
|
if (text.length > LARGE_TEXT_THRESHOLD) {
|
||||||
|
// We need to prevent the default paste behavior
|
||||||
|
// But since we're in an async callback, we need to
|
||||||
|
// manually clear the textarea's selection value
|
||||||
|
setTimeout(() => {
|
||||||
|
if (textAreaRef.current) {
|
||||||
|
const currentValue = textAreaRef.current.value;
|
||||||
|
const selectionStart =
|
||||||
|
textAreaRef.current.selectionStart || 0;
|
||||||
|
const selectionEnd = textAreaRef.current.selectionEnd || 0;
|
||||||
|
|
||||||
|
// Remove the pasted text from the textarea
|
||||||
|
const newValue =
|
||||||
|
currentValue.substring(0, selectionStart - text.length) +
|
||||||
|
currentValue.substring(selectionEnd);
|
||||||
|
|
||||||
|
// Update the textarea
|
||||||
|
textAreaRef.current.value = newValue;
|
||||||
|
// Trigger the onChange event manually
|
||||||
|
setText(newValue);
|
||||||
|
}
|
||||||
|
}, 0);
|
||||||
|
|
||||||
|
// Prevent default paste for large text
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
// Create a text file from the pasted content
|
||||||
|
const blob = new Blob([text], { type: "text/plain" });
|
||||||
|
const file = new File(
|
||||||
|
[blob],
|
||||||
|
`pasted-text-${new Date().getTime()}.txt`,
|
||||||
|
{ type: "text/plain" }
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add to file list
|
||||||
|
const uploadFile: UploadFile = {
|
||||||
|
uid: `paste-${Date.now()}`,
|
||||||
|
name: file.name,
|
||||||
|
status: "done",
|
||||||
|
size: file.size,
|
||||||
|
type: file.type,
|
||||||
|
originFileObj: file as RcFile,
|
||||||
|
};
|
||||||
|
|
||||||
|
setFileList((prev) => [...prev, uploadFile]);
|
||||||
|
|
||||||
|
// Notify user about the conversion
|
||||||
|
notificationApi.info({
|
||||||
|
message: (
|
||||||
|
<span className="text-sm">
|
||||||
|
Large Text Converted to File
|
||||||
|
</span>
|
||||||
|
),
|
||||||
|
description: (
|
||||||
|
<span className="text-sm text-secondary">
|
||||||
|
Your pasted text has been attached as a file.
|
||||||
|
</span>
|
||||||
|
),
|
||||||
|
duration: 3,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add the paste event listener to the document
|
||||||
|
document.addEventListener("paste", handlePaste);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener("paste", handlePaste);
|
||||||
|
};
|
||||||
|
}, [isInputDisabled, notificationApi]);
|
||||||
|
|
||||||
const resetInput = () => {
|
const resetInput = () => {
|
||||||
if (textAreaRef.current) {
|
if (textAreaRef.current) {
|
||||||
textAreaRef.current.value = "";
|
textAreaRef.current.value = "";
|
||||||
textAreaRef.current.style.height = textAreaDefaultHeight;
|
textAreaRef.current.style.height = textAreaDefaultHeight;
|
||||||
setText("");
|
setText("");
|
||||||
|
setFileList([]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -58,9 +214,18 @@ export default function ChatInput({
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleSubmit = () => {
|
const handleSubmit = () => {
|
||||||
if (textAreaRef.current?.value && !isInputDisabled) {
|
if (
|
||||||
const query = textAreaRef.current.value;
|
(textAreaRef.current?.value || fileList.length > 0) &&
|
||||||
onSubmit(query);
|
!isInputDisabled
|
||||||
|
) {
|
||||||
|
const query = textAreaRef.current?.value || "";
|
||||||
|
|
||||||
|
// Get all valid RcFile objects
|
||||||
|
const files = fileList
|
||||||
|
.filter((file) => file.originFileObj)
|
||||||
|
.map((file) => file.originFileObj as RcFile);
|
||||||
|
|
||||||
|
onSubmit(query, files);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -71,12 +236,161 @@ export default function ChatInput({
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const uploadProps: UploadProps = {
|
||||||
|
name: "file",
|
||||||
|
multiple: true,
|
||||||
|
fileList,
|
||||||
|
beforeUpload: (file) => {
|
||||||
|
// Check file size
|
||||||
|
if (file.size > MAX_FILE_SIZE) {
|
||||||
|
message.error(`${file.name} is too large. Maximum size is 5MB.`);
|
||||||
|
return Upload.LIST_IGNORE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file type
|
||||||
|
if (!ALLOWED_FILE_TYPES.includes(file.type)) {
|
||||||
|
notificationApi.warning({
|
||||||
|
message: <span className="text-sm">Unsupported File Type</span>,
|
||||||
|
description: (
|
||||||
|
<span className="text-sm text-secondary">
|
||||||
|
Please upload only text (.txt) or images (.jpg, .png, .gif, .svg)
|
||||||
|
files.
|
||||||
|
</span>
|
||||||
|
),
|
||||||
|
showProgress: true,
|
||||||
|
duration: 8.5,
|
||||||
|
});
|
||||||
|
return Upload.LIST_IGNORE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Correctly set the uploadFile with originFileObj property
|
||||||
|
const uploadFile: UploadFile = {
|
||||||
|
uid: file.uid,
|
||||||
|
name: file.name,
|
||||||
|
status: "done",
|
||||||
|
size: file.size,
|
||||||
|
type: file.type,
|
||||||
|
originFileObj: file,
|
||||||
|
};
|
||||||
|
|
||||||
|
setFileList((prev) => [...prev, uploadFile]);
|
||||||
|
return false; // Prevent automatic upload
|
||||||
|
},
|
||||||
|
onRemove: (file) => {
|
||||||
|
setFileList(fileList.filter((item) => item.uid !== file.uid));
|
||||||
|
},
|
||||||
|
showUploadList: false, // We'll handle our own custom file preview
|
||||||
|
customRequest: ({ onSuccess }) => {
|
||||||
|
// Mock successful upload since we're not actually uploading anywhere yet
|
||||||
|
if (onSuccess) onSuccess("ok");
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const getFileIcon = (file: UploadFile) => {
|
||||||
|
const fileType = file.type || "";
|
||||||
|
if (fileType.startsWith("image/")) {
|
||||||
|
return <ImageIcon className="w-4 h-4" />;
|
||||||
|
}
|
||||||
|
return <FileTextIcon className="w-4 h-4" />;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add these new event handler functions to the component
|
||||||
|
const handleDragOver = (e: React.DragEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
if (!isInputDisabled) {
|
||||||
|
setDragOver(true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDragLeave = (e: React.DragEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
setDragOver(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDrop = (e: React.DragEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
setDragOver(false);
|
||||||
|
|
||||||
|
if (isInputDisabled) return;
|
||||||
|
|
||||||
|
const droppedFiles = Array.from(e.dataTransfer.files);
|
||||||
|
|
||||||
|
droppedFiles.forEach((file) => {
|
||||||
|
// Check file size
|
||||||
|
if (file.size > MAX_FILE_SIZE) {
|
||||||
|
message.error(`${file.name} is too large. Maximum size is 5MB.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check file type
|
||||||
|
if (!ALLOWED_FILE_TYPES.includes(file.type)) {
|
||||||
|
notificationApi.warning({
|
||||||
|
message: <span className="text-sm">Unsupported File Type</span>,
|
||||||
|
description: (
|
||||||
|
<span className="text-sm text-secondary">
|
||||||
|
Please upload only text (.txt) or images (.jpg, .png, .gif, .svg)
|
||||||
|
files.
|
||||||
|
</span>
|
||||||
|
),
|
||||||
|
showProgress: true,
|
||||||
|
duration: 8.5,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to file list
|
||||||
|
const uploadFile: UploadFile = {
|
||||||
|
uid: `file-${Date.now()}-${file.name}`,
|
||||||
|
name: file.name,
|
||||||
|
status: "done",
|
||||||
|
size: file.size,
|
||||||
|
type: file.type,
|
||||||
|
originFileObj: file as RcFile,
|
||||||
|
};
|
||||||
|
|
||||||
|
setFileList((prev) => [...prev, uploadFile]);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="mt-2 w-full">
|
<div className="mt-2 w-full">
|
||||||
|
{notificationContextHolder}
|
||||||
|
{/* File previews */}
|
||||||
|
{fileList.length > 0 && (
|
||||||
|
<div className="-mb-2 mx-1 bg-tertiary rounded-t border-b-0 p-2 flex bodrder flex-wrap gap-2">
|
||||||
|
{fileList.map((file) => (
|
||||||
<div
|
<div
|
||||||
className={`mt-2 rounded shadow-sm flex mb-1 ${
|
key={file.uid}
|
||||||
isInputDisabled ? "opacity-50" : ""
|
className="flex items-center gap-1 bg-secondary rounded px-2 py-1 text-xs"
|
||||||
}`}
|
>
|
||||||
|
{getFileIcon(file)}
|
||||||
|
<span className="truncate max-w-[150px]">
|
||||||
|
{truncateText(file.name, 20)}
|
||||||
|
</span>
|
||||||
|
<Button
|
||||||
|
type="text"
|
||||||
|
size="small"
|
||||||
|
className="p-0 ml-1 flex items-center justify-center"
|
||||||
|
onClick={() =>
|
||||||
|
setFileList((prev) => prev.filter((f) => f.uid !== file.uid))
|
||||||
|
}
|
||||||
|
icon={<XIcon className="w-3 h-3" />}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div
|
||||||
|
className={`mt-2 rounded shadow-sm flex mb-1 transition-all duration-200 ${
|
||||||
|
dragOver ? "ring-2 ring-blue-400" : ""
|
||||||
|
} ${isInputDisabled ? "opacity-50" : ""}`}
|
||||||
|
onDragOver={handleDragOver}
|
||||||
|
onDragLeave={handleDragLeave}
|
||||||
|
onDrop={handleDrop}
|
||||||
>
|
>
|
||||||
<form
|
<form
|
||||||
className="flex-1 relative"
|
className="flex-1 relative"
|
||||||
|
@ -85,30 +399,75 @@ export default function ChatInput({
|
||||||
handleSubmit();
|
handleSubmit();
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
|
{dragOver && (
|
||||||
|
<div className="absolute inset-0 bg-blue-100 bg-opacity-60 flex items-center justify-center rounded z-10 pointer-events-none">
|
||||||
|
<div className="text-accent tex-xs items-center">
|
||||||
|
<UploadIcon className="h-4 w-4 mr-2 inline-block" />
|
||||||
|
<span className="text-xs inline-block">Drop files here</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<textarea
|
<textarea
|
||||||
id="queryInput"
|
id="queryInput"
|
||||||
name="queryInput"
|
name="queryInput"
|
||||||
ref={textAreaRef}
|
ref={textAreaRef}
|
||||||
defaultValue={"what is the height of the eiffel tower"}
|
value={text}
|
||||||
onChange={handleTextChange}
|
onChange={handleTextChange}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
className={`flex items-center w-full resize-none text-gray-600 rounded border border-accent bg-white p-2 pl-5 pr-16 ${
|
className={`flex items-center w-full resize-none text-gray-600 rounded ${
|
||||||
isInputDisabled ? "cursor-not-allowed" : ""
|
dragOver
|
||||||
}`}
|
? "border-2 border-blue-500 bg-blue-50"
|
||||||
|
: "border border-accent bg-white"
|
||||||
|
} p-2 pl-5 pr-16 ${isInputDisabled ? "cursor-not-allowed" : ""}`}
|
||||||
style={{
|
style={{
|
||||||
maxHeight: "120px",
|
maxHeight: "120px",
|
||||||
overflowY: "auto",
|
overflowY: "auto",
|
||||||
minHeight: "50px",
|
minHeight: "50px",
|
||||||
|
transition: "all 0.2s ease-in-out",
|
||||||
}}
|
}}
|
||||||
placeholder="Type your message here..."
|
placeholder={
|
||||||
|
dragOver ? "Drop files here..." : "Type your message here..."
|
||||||
|
}
|
||||||
disabled={isInputDisabled}
|
disabled={isInputDisabled}
|
||||||
/>
|
/>
|
||||||
|
<div className={`absolute right-3 bottom-2 flex gap-2`}>
|
||||||
|
<div
|
||||||
|
className={` ${
|
||||||
|
disabled || isInputDisabled
|
||||||
|
? " opacity-50 pointer-events-none "
|
||||||
|
: ""
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{" "}
|
||||||
|
<Upload className="zero-padding-upload " {...uploadProps}>
|
||||||
|
<Tooltip
|
||||||
|
title=<span className="text-sm">
|
||||||
|
Upload File{" "}
|
||||||
|
<span className="text-secondary text-xs">(max 5mb)</span>
|
||||||
|
</span>
|
||||||
|
placement="top"
|
||||||
|
>
|
||||||
|
<Button type="text" disabled={isInputDisabled} className=" ">
|
||||||
|
<UploadIcon
|
||||||
|
strokeWidth={2}
|
||||||
|
size={26}
|
||||||
|
className="p-1 inline-block w-8 text-accent"
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
</Tooltip>
|
||||||
|
</Upload>
|
||||||
|
</div>
|
||||||
|
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={handleSubmit}
|
onClick={handleSubmit}
|
||||||
disabled={isInputDisabled}
|
disabled={
|
||||||
className={`absolute right-3 bottom-2 bg-accent transition duration-300 rounded flex justify-center items-center w-11 h-9 ${
|
isInputDisabled || (text.trim() === "" && fileList.length === 0)
|
||||||
isInputDisabled ? "cursor-not-allowed" : "hover:brightness-75"
|
}
|
||||||
|
className={`bg-accent transition duration-300 rounded flex justify-center items-center w-11 h-9 ${
|
||||||
|
isInputDisabled || (text.trim() === "" && fileList.length === 0)
|
||||||
|
? "cursor-not-allowed opacity-50"
|
||||||
|
: "hover:brightness-75"
|
||||||
}`}
|
}`}
|
||||||
>
|
>
|
||||||
{loading ? (
|
{loading ? (
|
||||||
|
@ -117,6 +476,7 @@ export default function ChatInput({
|
||||||
<PaperAirplaneIcon className="h-6 w-6 text-white" />
|
<PaperAirplaneIcon className="h-6 w-6 text-white" />
|
||||||
)}
|
)}
|
||||||
</button>
|
</button>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
|
@ -25,9 +25,10 @@ const getImageSource = (item: ImageContent): string => {
|
||||||
return "/api/placeholder/400/320";
|
return "/api/placeholder/400/320";
|
||||||
};
|
};
|
||||||
|
|
||||||
const RenderMultiModal: React.FC<{ content: (string | ImageContent)[] }> = ({
|
const RenderMultiModal: React.FC<{
|
||||||
content,
|
content: (string | ImageContent)[];
|
||||||
}) => (
|
thumbnail?: boolean;
|
||||||
|
}> = ({ content, thumbnail = false }) => (
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
{content.map((item, index) =>
|
{content.map((item, index) =>
|
||||||
typeof item === "string" ? (
|
typeof item === "string" ? (
|
||||||
|
@ -37,7 +38,9 @@ const RenderMultiModal: React.FC<{ content: (string | ImageContent)[] }> = ({
|
||||||
key={index}
|
key={index}
|
||||||
src={getImageSource(item)}
|
src={getImageSource(item)}
|
||||||
alt={item.alt || "Image"}
|
alt={item.alt || "Image"}
|
||||||
className="w-full h-auto rounded border border-secondary"
|
className={` h-auto rounded border border-secondary ${
|
||||||
|
thumbnail ? "w-24 h-24 " : " w-full "
|
||||||
|
}`}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
)}
|
)}
|
||||||
|
@ -101,6 +104,18 @@ export const messageUtils = {
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
|
||||||
|
isNestedMessageContent(content: unknown): content is AgentMessageConfig[] {
|
||||||
|
if (!Array.isArray(content)) return false;
|
||||||
|
return content.every(
|
||||||
|
(item) =>
|
||||||
|
typeof item === "object" &&
|
||||||
|
item !== null &&
|
||||||
|
"source" in item &&
|
||||||
|
"content" in item &&
|
||||||
|
"type" in item
|
||||||
|
);
|
||||||
|
},
|
||||||
|
|
||||||
isMultiModalContent(content: unknown): content is (string | ImageContent)[] {
|
isMultiModalContent(content: unknown): content is (string | ImageContent)[] {
|
||||||
if (!Array.isArray(content)) return false;
|
if (!Array.isArray(content)) return false;
|
||||||
return content.every(
|
return content.every(
|
||||||
|
@ -128,20 +143,66 @@ export const messageUtils = {
|
||||||
isUser(source: string): boolean {
|
isUser(source: string): boolean {
|
||||||
return source === "user";
|
return source === "user";
|
||||||
},
|
},
|
||||||
|
|
||||||
|
isMessageArray(
|
||||||
|
message: AgentMessageConfig | AgentMessageConfig[]
|
||||||
|
): message is AgentMessageConfig[] {
|
||||||
|
return Array.isArray(message);
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
interface MessageProps {
|
interface MessageProps {
|
||||||
message: AgentMessageConfig;
|
message: AgentMessageConfig | AgentMessageConfig[];
|
||||||
isLast?: boolean;
|
isLast?: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const RenderNestedMessages: React.FC<{
|
||||||
|
content: AgentMessageConfig[];
|
||||||
|
}> = ({ content }) => (
|
||||||
|
<div className="space-y-4">
|
||||||
|
{content.map((item, index) => (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className={`${
|
||||||
|
index > 0 ? "bordper border-secondary rounded bg-secondary/30" : ""
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{typeof item.content === "string" ? (
|
||||||
|
<TruncatableText
|
||||||
|
content={item.content}
|
||||||
|
className={`break-all ${index === 0 ? "text-base" : "text-sm"}`}
|
||||||
|
/>
|
||||||
|
) : messageUtils.isMultiModalContent(item.content) ? (
|
||||||
|
<RenderMultiModal content={item.content} thumbnail />
|
||||||
|
) : (
|
||||||
|
<pre className="text-xs whitespace-pre-wrap overflow-x-auto">
|
||||||
|
{JSON.stringify(item.content, null, 2)}
|
||||||
|
</pre>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
export const RenderMessage: React.FC<MessageProps> = ({
|
export const RenderMessage: React.FC<MessageProps> = ({
|
||||||
message,
|
message,
|
||||||
isLast = false,
|
isLast = false,
|
||||||
className = "",
|
className = "",
|
||||||
}) => {
|
}) => {
|
||||||
if (!message) return null;
|
if (!message) return null;
|
||||||
|
|
||||||
|
// If message is an array, render the first message or return null
|
||||||
|
if (messageUtils.isMessageArray(message)) {
|
||||||
|
return message.length > 0 ? (
|
||||||
|
<RenderMessage
|
||||||
|
message={message[0]}
|
||||||
|
isLast={isLast}
|
||||||
|
className={className}
|
||||||
|
/>
|
||||||
|
) : null;
|
||||||
|
}
|
||||||
|
|
||||||
const isUser = messageUtils.isUser(message.source);
|
const isUser = messageUtils.isUser(message.source);
|
||||||
const content = message.content;
|
const content = message.content;
|
||||||
const isLLMEventMessage = message.source === "llm_call_event";
|
const isLLMEventMessage = message.source === "llm_call_event";
|
||||||
|
@ -186,7 +247,9 @@ export const RenderMessage: React.FC<MessageProps> = ({
|
||||||
{messageUtils.isToolCallContent(content) ? (
|
{messageUtils.isToolCallContent(content) ? (
|
||||||
<RenderToolCall content={content} />
|
<RenderToolCall content={content} />
|
||||||
) : messageUtils.isMultiModalContent(content) ? (
|
) : messageUtils.isMultiModalContent(content) ? (
|
||||||
<RenderMultiModal content={content} />
|
<RenderMultiModal content={content} thumbnail />
|
||||||
|
) : messageUtils.isNestedMessageContent(content) ? (
|
||||||
|
<RenderNestedMessages content={content} />
|
||||||
) : messageUtils.isFunctionExecutionResult(content) ? (
|
) : messageUtils.isFunctionExecutionResult(content) ? (
|
||||||
<RenderToolResult content={content} />
|
<RenderToolResult content={content} />
|
||||||
) : message.source === "llm_call_event" ? (
|
) : message.source === "llm_call_event" ? (
|
||||||
|
@ -198,7 +261,6 @@ export const RenderMessage: React.FC<MessageProps> = ({
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{message.models_usage && (
|
{message.models_usage && (
|
||||||
<div className="text-xs text-secondary mt-1">
|
<div className="text-xs text-secondary mt-1">
|
||||||
Tokens:{" "}
|
Tokens:{" "}
|
||||||
|
|
|
@ -113,6 +113,8 @@ const RunView: React.FC<RunViewProps> = ({
|
||||||
return run.messages.filter((msg) => msg.config.source !== "llm_call_event");
|
return run.messages.filter((msg) => msg.config.source !== "llm_call_event");
|
||||||
}, [run.messages, uiSettings.show_llm_call_events]);
|
}, [run.messages, uiSettings.show_llm_call_events]);
|
||||||
|
|
||||||
|
console.log("Run task", run.task);
|
||||||
|
|
||||||
// Replace existing scroll effect with this simpler one
|
// Replace existing scroll effect with this simpler one
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
|
|
|
@ -33,7 +33,7 @@ import {
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { useTeamBuilderStore } from "./store";
|
import { useTeamBuilderStore } from "./store";
|
||||||
import { ComponentLibrary } from "./library";
|
import { ComponentLibrary } from "./library";
|
||||||
import { ComponentTypes, Team } from "../../../types/datamodel";
|
import { ComponentTypes, Gallery, Team } from "../../../types/datamodel";
|
||||||
import { CustomNode, CustomEdge, DragItem } from "./types";
|
import { CustomNode, CustomEdge, DragItem } from "./types";
|
||||||
import { edgeTypes, nodeTypes } from "./nodes";
|
import { edgeTypes, nodeTypes } from "./nodes";
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ import TestDrawer from "./testdrawer";
|
||||||
import { validationAPI, ValidationResponse } from "../api";
|
import { validationAPI, ValidationResponse } from "../api";
|
||||||
import { ValidationErrors } from "./validationerrors";
|
import { ValidationErrors } from "./validationerrors";
|
||||||
import ComponentEditor from "./component-editor/component-editor";
|
import ComponentEditor from "./component-editor/component-editor";
|
||||||
import { useGalleryStore } from "../../gallery/store";
|
// import { useGalleryStore } from "../../gallery/store";
|
||||||
|
|
||||||
const { Sider, Content } = Layout;
|
const { Sider, Content } = Layout;
|
||||||
interface DragItemData {
|
interface DragItemData {
|
||||||
|
@ -60,12 +60,14 @@ interface TeamBuilderProps {
|
||||||
team: Team;
|
team: Team;
|
||||||
onChange?: (team: Partial<Team>) => void;
|
onChange?: (team: Partial<Team>) => void;
|
||||||
onDirtyStateChange?: (isDirty: boolean) => void;
|
onDirtyStateChange?: (isDirty: boolean) => void;
|
||||||
|
selectedGallery?: Gallery | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const TeamBuilder: React.FC<TeamBuilderProps> = ({
|
export const TeamBuilder: React.FC<TeamBuilderProps> = ({
|
||||||
team,
|
team,
|
||||||
onChange,
|
onChange,
|
||||||
onDirtyStateChange,
|
onDirtyStateChange,
|
||||||
|
selectedGallery,
|
||||||
}) => {
|
}) => {
|
||||||
// Replace store state with React Flow hooks
|
// Replace store state with React Flow hooks
|
||||||
const [nodes, setNodes, onNodesChange] = useNodesState<CustomNode>([]);
|
const [nodes, setNodes, onNodesChange] = useNodesState<CustomNode>([]);
|
||||||
|
@ -86,7 +88,7 @@ export const TeamBuilder: React.FC<TeamBuilderProps> = ({
|
||||||
const [validationLoading, setValidationLoading] = useState(false);
|
const [validationLoading, setValidationLoading] = useState(false);
|
||||||
|
|
||||||
const [testDrawerVisible, setTestDrawerVisible] = useState(false);
|
const [testDrawerVisible, setTestDrawerVisible] = useState(false);
|
||||||
const defaultGallery = useGalleryStore((state) => state.getSelectedGallery());
|
|
||||||
const {
|
const {
|
||||||
undo,
|
undo,
|
||||||
redo,
|
redo,
|
||||||
|
@ -465,8 +467,8 @@ export const TeamBuilder: React.FC<TeamBuilderProps> = ({
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
>
|
>
|
||||||
<Layout className=" relative bg-primary h-[calc(100vh-239px)] rounded">
|
<Layout className=" relative bg-primary h-[calc(100vh-239px)] rounded">
|
||||||
{!isJsonMode && defaultGallery && (
|
{!isJsonMode && selectedGallery && (
|
||||||
<ComponentLibrary defaultGallery={defaultGallery} />
|
<ComponentLibrary defaultGallery={selectedGallery} />
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<Layout className="bg-primary rounded">
|
<Layout className="bg-primary rounded">
|
||||||
|
|
|
@ -5,7 +5,7 @@ import { appContext } from "../../../hooks/provider";
|
||||||
import { teamAPI } from "./api";
|
import { teamAPI } from "./api";
|
||||||
import { useGalleryStore } from "../gallery/store";
|
import { useGalleryStore } from "../gallery/store";
|
||||||
import { TeamSidebar } from "./sidebar";
|
import { TeamSidebar } from "./sidebar";
|
||||||
import type { Team } from "../../types/datamodel";
|
import { Gallery, type Team } from "../../types/datamodel";
|
||||||
import { TeamBuilder } from "./builder/builder";
|
import { TeamBuilder } from "./builder/builder";
|
||||||
|
|
||||||
export const TeamManager: React.FC = () => {
|
export const TeamManager: React.FC = () => {
|
||||||
|
@ -19,18 +19,12 @@ export const TeamManager: React.FC = () => {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const [selectedGallery, setSelectedGallery] = useState<Gallery | null>(null);
|
||||||
|
|
||||||
const { user } = useContext(appContext);
|
const { user } = useContext(appContext);
|
||||||
const [messageApi, contextHolder] = message.useMessage();
|
const [messageApi, contextHolder] = message.useMessage();
|
||||||
const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false);
|
const [hasUnsavedChanges, setHasUnsavedChanges] = useState(false);
|
||||||
|
|
||||||
// Initialize galleries
|
|
||||||
const fetchGalleries = useGalleryStore((state) => state.fetchGalleries);
|
|
||||||
useEffect(() => {
|
|
||||||
if (user?.id) {
|
|
||||||
fetchGalleries(user.id);
|
|
||||||
}
|
|
||||||
}, [user?.id, fetchGalleries]);
|
|
||||||
|
|
||||||
// Persist sidebar state
|
// Persist sidebar state
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (typeof window !== "undefined") {
|
if (typeof window !== "undefined") {
|
||||||
|
@ -171,6 +165,8 @@ export const TeamManager: React.FC = () => {
|
||||||
onEditTeam={setCurrentTeam}
|
onEditTeam={setCurrentTeam}
|
||||||
onDeleteTeam={handleDeleteTeam}
|
onDeleteTeam={handleDeleteTeam}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
|
setSelectedGallery={setSelectedGallery}
|
||||||
|
selectedGallery={selectedGallery}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
@ -205,6 +201,7 @@ export const TeamManager: React.FC = () => {
|
||||||
team={currentTeam}
|
team={currentTeam}
|
||||||
onChange={handleSaveTeam}
|
onChange={handleSaveTeam}
|
||||||
onDirtyStateChange={setHasUnsavedChanges}
|
onDirtyStateChange={setHasUnsavedChanges}
|
||||||
|
selectedGallery={selectedGallery}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<div className="flex items-center justify-center h-[calc(100vh-190px)] text-secondary">
|
<div className="flex items-center justify-center h-[calc(100vh-190px)] text-secondary">
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import React, { useState } from "react";
|
import React, { useContext, useState } from "react";
|
||||||
import { Button, Tooltip, Select, message } from "antd";
|
import { Button, Tooltip, Select, message } from "antd";
|
||||||
import {
|
import {
|
||||||
Bot,
|
Bot,
|
||||||
|
@ -12,9 +12,12 @@ import {
|
||||||
RefreshCcw,
|
RefreshCcw,
|
||||||
History,
|
History,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import type { Team } from "../../types/datamodel";
|
import type { Gallery, Team } from "../../types/datamodel";
|
||||||
import { getRelativeTimeString } from "../atoms";
|
import { getRelativeTimeString } from "../atoms";
|
||||||
import { useGalleryStore } from "../gallery/store";
|
import { GalleryAPI } from "../gallery/api";
|
||||||
|
import { appContext } from "../../../hooks/provider";
|
||||||
|
import { Link } from "gatsby";
|
||||||
|
import { getLocalStorage, setLocalStorage } from "../../utils/utils";
|
||||||
|
|
||||||
interface TeamSidebarProps {
|
interface TeamSidebarProps {
|
||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
|
@ -26,6 +29,8 @@ interface TeamSidebarProps {
|
||||||
onEditTeam: (team: Team) => void;
|
onEditTeam: (team: Team) => void;
|
||||||
onDeleteTeam: (teamId: number) => void;
|
onDeleteTeam: (teamId: number) => void;
|
||||||
isLoading?: boolean;
|
isLoading?: boolean;
|
||||||
|
selectedGallery: Gallery | null;
|
||||||
|
setSelectedGallery: (gallery: Gallery) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const TeamSidebar: React.FC<TeamSidebarProps> = ({
|
export const TeamSidebar: React.FC<TeamSidebarProps> = ({
|
||||||
|
@ -38,18 +43,50 @@ export const TeamSidebar: React.FC<TeamSidebarProps> = ({
|
||||||
onEditTeam,
|
onEditTeam,
|
||||||
onDeleteTeam,
|
onDeleteTeam,
|
||||||
isLoading = false,
|
isLoading = false,
|
||||||
|
selectedGallery,
|
||||||
|
setSelectedGallery,
|
||||||
}) => {
|
}) => {
|
||||||
// Tab state - "recent" or "gallery"
|
// Tab state - "recent" or "gallery"
|
||||||
const [activeTab, setActiveTab] = useState<"recent" | "gallery">("recent");
|
const [activeTab, setActiveTab] = useState<"recent" | "gallery">("recent");
|
||||||
const [messageApi, contextHolder] = message.useMessage();
|
const [messageApi, contextHolder] = message.useMessage();
|
||||||
|
|
||||||
// Gallery store
|
const [isLoadingGalleries, setIsLoadingGalleries] = useState(false);
|
||||||
const {
|
const [galleries, setGalleries] = useState<Gallery[]>([]);
|
||||||
galleries,
|
const { user } = useContext(appContext);
|
||||||
selectedGallery,
|
|
||||||
selectGallery,
|
// Fetch galleries
|
||||||
isLoading: isLoadingGalleries,
|
|
||||||
} = useGalleryStore();
|
const fetchGalleries = async () => {
|
||||||
|
if (!user?.id) return;
|
||||||
|
setIsLoadingGalleries(true);
|
||||||
|
try {
|
||||||
|
const galleryAPI = new GalleryAPI();
|
||||||
|
const data = await galleryAPI.listGalleries(user.id);
|
||||||
|
setGalleries(data);
|
||||||
|
|
||||||
|
// Check localStorage for a previously saved gallery ID
|
||||||
|
const savedGalleryId = getLocalStorage(`selectedGalleryId_${user.id}`);
|
||||||
|
|
||||||
|
if (savedGalleryId && data.length > 0) {
|
||||||
|
const savedGallery = data.find((g) => g.id === savedGalleryId);
|
||||||
|
if (savedGallery) {
|
||||||
|
setSelectedGallery(savedGallery);
|
||||||
|
} else if (!selectedGallery && data.length > 0) {
|
||||||
|
setSelectedGallery(data[0]);
|
||||||
|
}
|
||||||
|
} else if (!selectedGallery && data.length > 0) {
|
||||||
|
setSelectedGallery(data[0]);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error fetching galleries:", error);
|
||||||
|
} finally {
|
||||||
|
setIsLoadingGalleries(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// Fetch galleries on mount
|
||||||
|
React.useEffect(() => {
|
||||||
|
fetchGalleries();
|
||||||
|
}, [user?.id]);
|
||||||
|
|
||||||
// Render collapsed state
|
// Render collapsed state
|
||||||
if (!isOpen) {
|
if (!isOpen) {
|
||||||
|
@ -262,13 +299,28 @@ export const TeamSidebar: React.FC<TeamSidebarProps> = ({
|
||||||
{activeTab === "gallery" && (
|
{activeTab === "gallery" && (
|
||||||
<div className="p-2">
|
<div className="p-2">
|
||||||
{/* Gallery Selector */}
|
{/* Gallery Selector */}
|
||||||
|
<div className="my-2 mb-3 text-xs">
|
||||||
|
{" "}
|
||||||
|
Select a{" "}
|
||||||
|
<Link to="/gallery" className="text-accent">
|
||||||
|
<span className="font-medium">gallery</span>
|
||||||
|
</Link>{" "}
|
||||||
|
to view its components as templates
|
||||||
|
</div>
|
||||||
<Select
|
<Select
|
||||||
className="w-full mb-4"
|
className="w-full mb-4"
|
||||||
placeholder="Select gallery"
|
placeholder="Select gallery"
|
||||||
value={selectedGallery?.id}
|
value={selectedGallery?.id}
|
||||||
onChange={(value) => {
|
onChange={(value) => {
|
||||||
const gallery = galleries.find((g) => g.id === value);
|
const gallery = galleries.find((g) => g.id === value);
|
||||||
if (gallery) selectGallery(gallery);
|
if (gallery) {
|
||||||
|
setSelectedGallery(gallery);
|
||||||
|
|
||||||
|
// Save the selected gallery ID to localStorage
|
||||||
|
if (user?.id) {
|
||||||
|
setLocalStorage(`selectedGalleryId_${user.id}`, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
options={galleries.map((gallery) => ({
|
options={galleries.map((gallery) => ({
|
||||||
value: gallery.id,
|
value: gallery.id,
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
import * as React from "react";
|
||||||
|
import Layout from "../components/layout";
|
||||||
|
import { graphql } from "gatsby";
|
||||||
|
import DeployManager from "../components/views/deploy/manager";
|
||||||
|
import LabsManager from "../components/views/labs/manager";
|
||||||
|
|
||||||
|
// markup
|
||||||
|
const LabsPage = ({ data }: any) => {
|
||||||
|
return (
|
||||||
|
<Layout meta={data.site.siteMetadata} title="Home" link={"/labs"}>
|
||||||
|
<main style={{ height: "100%" }} className=" h-full ">
|
||||||
|
<LabsManager />
|
||||||
|
</main>
|
||||||
|
</Layout>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const query = graphql`
|
||||||
|
query HomePageQuery {
|
||||||
|
site {
|
||||||
|
siteMetadata {
|
||||||
|
description
|
||||||
|
title
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
export default LabsPage;
|
|
@ -381,3 +381,12 @@ div#gatsby-focus-wrapper {
|
||||||
height: 100%;
|
height: 100%;
|
||||||
/* border: 1px solid green; */
|
/* border: 1px solid green; */
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.zero-padding-upload.ant-upload,
|
||||||
|
.zero-padding-upload .ant-upload,
|
||||||
|
.zero-padding-upload .ant-upload-select,
|
||||||
|
.zero-padding-upload .ant-btn {
|
||||||
|
padding: 0 !important;
|
||||||
|
margin: 0 !important;
|
||||||
|
border: none !important;
|
||||||
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ class TestDatabaseOperations:
|
||||||
def test_basic_setup(self, test_db: DatabaseManager):
|
def test_basic_setup(self, test_db: DatabaseManager):
|
||||||
"""Test basic database setup and connection"""
|
"""Test basic database setup and connection"""
|
||||||
with Session(test_db.engine) as session:
|
with Session(test_db.engine) as session:
|
||||||
result = session.exec(text("SELECT 1")).first()
|
result = session.exec(text("SELECT 1")).first() # type: ignore
|
||||||
assert result[0] == 1
|
assert result[0] == 1
|
||||||
result = session.exec(select(1)).first()
|
result = session.exec(select(1)).first()
|
||||||
assert result == 1
|
assert result == 1
|
||||||
|
@ -85,7 +85,7 @@ class TestDatabaseOperations:
|
||||||
# Verify Update
|
# Verify Update
|
||||||
result = test_db.get(Team, {"id": team_id})
|
result = test_db.get(Team, {"id": team_id})
|
||||||
assert result.status is True
|
assert result.status is True
|
||||||
assert result.data[0].version == "0.0.2"
|
assert result.data and result.data[0].version == "0.0.2"
|
||||||
|
|
||||||
def test_delete_operations(self, test_db: DatabaseManager, sample_team: Team):
|
def test_delete_operations(self, test_db: DatabaseManager, sample_team: Team):
|
||||||
"""Test delete with various filters"""
|
"""Test delete with various filters"""
|
||||||
|
@ -103,6 +103,7 @@ class TestDatabaseOperations:
|
||||||
|
|
||||||
# Verify deletion
|
# Verify deletion
|
||||||
result = test_db.get(Team, {"id": team_id})
|
result = test_db.get(Team, {"id": team_id})
|
||||||
|
if result.data:
|
||||||
assert len(result.data) == 0
|
assert len(result.data) == 0
|
||||||
|
|
||||||
def test_cascade_delete(self, test_db: DatabaseManager, test_user: str):
|
def test_cascade_delete(self, test_db: DatabaseManager, test_user: str):
|
||||||
|
@ -133,7 +134,9 @@ class TestDatabaseOperations:
|
||||||
))
|
))
|
||||||
|
|
||||||
test_db.delete(Run, {"id": run1_id})
|
test_db.delete(Run, {"id": run1_id})
|
||||||
assert len(test_db.get(Message, {"run_id": run1_id}).data) == 0, "Run->Message cascade failed"
|
db_message = test_db.get(Message, {"run_id": run1_id})
|
||||||
|
if db_message.data:
|
||||||
|
assert len(db_message.data) == 0, "Run->Message cascade failed"
|
||||||
|
|
||||||
# Test Session -> Run -> Message cascade
|
# Test Session -> Run -> Message cascade
|
||||||
session2 = SessionModel(user_id=test_user, team_id=team1.id, name="Session2")
|
session2 = SessionModel(user_id=test_user, team_id=team1.id, name="Session2")
|
||||||
|
@ -154,8 +157,12 @@ class TestDatabaseOperations:
|
||||||
))
|
))
|
||||||
|
|
||||||
test_db.delete(SessionModel, {"id": session2.id})
|
test_db.delete(SessionModel, {"id": session2.id})
|
||||||
assert len(test_db.get(Run, {"session_id": session2.id}).data) == 0, "Session->Run cascade failed"
|
session = test_db.get(SessionModel, {"id": session2.id})
|
||||||
assert len(test_db.get(Message, {"run_id": run2_id}).data) == 0, "Session->Run->Message cascade failed"
|
run = test_db.get(Run, {"id": run2_id})
|
||||||
|
if session.data:
|
||||||
|
assert len(session.data) == 0, "Session->Run cascade failed"
|
||||||
|
if run.data:
|
||||||
|
assert len(run.data) == 0, "Session->Run->Message cascade failed"
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
test_db.delete(Team, {"id": team1.id})
|
test_db.delete(Team, {"id": team1.id})
|
||||||
|
|
|
@ -808,7 +808,7 @@ requires-dist = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "autogenstudio"
|
name = "autogenstudio"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
source = { editable = "packages/autogen-studio" }
|
source = { editable = "packages/autogen-studio" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
|
|
Loading…
Reference in New Issue