# Mirror of https://github.com/microsoft/autogen.git
# Python source: 246 lines, 8.2 KiB
import asyncio
import os
from pathlib import Path
from typing import Generator

import pytest
from sqlmodel import Session, select, text

from autogenstudio.database import DatabaseManager
from autogenstudio.datamodel.db import Agent, LinkTypes, Model, Team, Tool
from autogenstudio.datamodel.types import (
    AgentTypes,
    AssistantAgentConfig,
    ComponentTypes,
    ModelTypes,
    OpenAIModelConfig,
    RoundRobinTeamConfig,
    StopMessageTerminationConfig,
    TeamTypes,
    TerminationTypes,
    ToolConfig,
    ToolTypes,
)


@pytest.fixture
def test_db(tmp_path: Path) -> Generator[DatabaseManager, None, None]:
    """Fixture for test database.

    Yields a ``DatabaseManager`` backed by a SQLite file inside pytest's
    per-test ``tmp_path``. This replaces a fixed ``test.db`` in the current
    working directory, which repeated or parallel runs could share or leave
    behind. Before the test runs, the database is reset and then initialized
    with ``auto_upgrade=False``.

    Teardown resets the database and closes the manager. ``close()`` is
    guaranteed to run even if the reset fails. The file is then removed on a
    best-effort basis: a failed removal is printed as a warning, not raised.
    """
    db_path = tmp_path / "test.db"
    db = DatabaseManager(f"sqlite:///{db_path}")
    db.reset_db()
    # Initialize database instead of create_db_and_tables
    db.initialize_database(auto_upgrade=False)

    yield db

    # Clean up. close() runs last so that any connection reset_db() may open
    # is released before the file is deleted.
    # NOTE(review): assumes close() disposes pooled connections; confirm
    # against DatabaseManager.close.
    try:
        db.reset_db()
    finally:
        asyncio.run(db.close())
    try:
        db_path.unlink(missing_ok=True)
    except OSError as e:
        print(f"Warning: Failed to remove test database file: {e}")


@pytest.fixture
def test_user() -> str:
    """Provide the user id passed as ``user_id`` to every ``sample_*`` fixture."""
    owner_email = "test_user@example.com"
    return owner_email


@pytest.fixture
def sample_model(test_user: str) -> Model:
    """Build a not-yet-persisted ``Model`` owned by ``test_user``.

    Its ``config`` is the ``model_dump()`` dict of a GPT-4 ``OpenAIModelConfig``.
    """
    openai_config = OpenAIModelConfig(
        model="gpt-4",
        model_type=ModelTypes.OPENAI,
        component_type=ComponentTypes.MODEL,
        version="1.0.0",
    )
    return Model(user_id=test_user, config=openai_config.model_dump())


@pytest.fixture
def sample_tool(test_user: str) -> Tool:
    """Build a not-yet-persisted ``Tool`` owned by ``test_user``.

    Its ``config`` is the ``model_dump()`` dict of a Python-function
    ``ToolConfig`` whose ``content`` holds the tool's source code.
    """
    tool_source = "async def test_func(x: str) -> str:\n return f'Test {x}'"
    tool_config = ToolConfig(
        name="test_tool",
        description="A test tool",
        content=tool_source,
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )
    return Tool(user_id=test_user, config=tool_config.model_dump())


@pytest.fixture
def sample_agent(test_user: str, sample_model: Model, sample_tool: Tool) -> Agent:
    """Build a not-yet-persisted assistant ``Agent`` owned by ``test_user``.

    The dict configs of ``sample_model`` and ``sample_tool`` are re-validated
    into typed configs and embedded in the agent's config. They become its
    model client and its single tool. No database link rows are created here.
    """
    model_client = OpenAIModelConfig.model_validate(sample_model.config)
    agent_tools = [ToolConfig.model_validate(sample_tool.config)]
    agent_config = AssistantAgentConfig(
        name="test_agent",
        agent_type=AgentTypes.ASSISTANT,
        model_client=model_client,
        tools=agent_tools,
        component_type=ComponentTypes.AGENT,
        version="1.0.0",
    )
    return Agent(user_id=test_user, config=agent_config.model_dump())


@pytest.fixture
def sample_team(test_user: str, sample_agent: Agent) -> Team:
    """Build a not-yet-persisted round-robin ``Team`` owned by ``test_user``.

    ``sample_agent`` is its only participant. The team stops on a stop-message
    termination condition.
    """
    participant = AssistantAgentConfig.model_validate(sample_agent.config)
    # The termination condition is handed over as a plain dict; presumably
    # RoundRobinTeamConfig validates it back into a typed config. Verify
    # against the field's declared type.
    stop_condition = StopMessageTerminationConfig(
        termination_type=TerminationTypes.STOP_MESSAGE,
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    ).model_dump()
    team_config = RoundRobinTeamConfig(
        name="test_team",
        participants=[participant],
        termination_condition=stop_condition,
        team_type=TeamTypes.ROUND_ROBIN,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    return Team(user_id=test_user, config=team_config.model_dump())


class TestDatabaseOperations:
    """Tests for ``DatabaseManager`` setup, CRUD, linking and initialization."""

    def test_basic_setup(self, test_db: DatabaseManager):
        """Test basic database setup and connection"""
        with Session(test_db.engine) as session:
            # A raw text() query yields a row, so index into it.
            result = session.exec(text("SELECT 1")).first()
            assert result[0] == 1

            # A SQLModel select() of a single expression yields the bare value.
            result = session.exec(select(1)).first()
            assert result == 1

    def test_basic_entity_creation(
        self,
        test_db: DatabaseManager,
        sample_model: Model,
        sample_tool: Tool,
        sample_agent: Agent,
        sample_team: Team,
    ):
        """Test creating all entity types with proper configs"""
        with Session(test_db.engine) as session:
            # Add all entities, in the same order as before.
            session.add_all([sample_model, sample_tool, sample_agent, sample_team])
            session.commit()

            # Read the primary keys while the session is still open.
            expected_ids = {
                Model: sample_model.id,
                Tool: sample_tool.id,
                Agent: sample_agent.id,
                Team: sample_team.id,
            }

        # Verify every entity through a brand-new session.
        with Session(test_db.engine) as session:
            for entity_cls, entity_id in expected_ids.items():
                assert session.get(entity_cls, entity_id) is not None, entity_cls.__name__

    def test_multiple_links(
        self, test_db: DatabaseManager, sample_agent: Agent, test_user: str
    ):
        """Test linking multiple models to an agent"""
        model_names = ("gpt-4", "gpt-3.5")

        with Session(test_db.engine) as session:
            # One model per name, owned by the same user as the agent.
            models = [
                Model(
                    user_id=test_user,
                    config=OpenAIModelConfig(
                        model=name,
                        model_type=ModelTypes.OPENAI,
                        component_type=ComponentTypes.MODEL,
                        version="1.0.0",
                    ).model_dump(),
                )
                for name in model_names
            ]

            # Add and commit all entities
            session.add_all(models)
            session.add(sample_agent)
            session.commit()

            model_ids = [model.id for model in models]
            agent_id = sample_agent.id

        # Create links using IDs
        for model_id in model_ids:
            test_db.link(LinkTypes.AGENT_MODEL, agent_id, model_id)

        # Verify links: exactly one linked model per name.
        linked_models = test_db.get_linked_entities(LinkTypes.AGENT_MODEL, agent_id)
        assert len(linked_models.data) == len(model_names)
        linked_names = {model.config["model"] for model in linked_models.data}
        assert linked_names == set(model_names)

    def test_upsert_operations(self, test_db: DatabaseManager, sample_model: Model):
        """Test upsert for both create and update scenarios"""
        # Create: no row exists for this model yet.
        response = test_db.upsert(sample_model)
        assert response.status is True
        assert "Created Successfully" in response.message

        # Update: mutate the config dict in place and upsert the same object.
        sample_model.config["model"] = "gpt-4-turbo"
        response = test_db.upsert(sample_model)
        assert response.status is True
        assert "Updated Successfully" in response.message

        # Verify Update
        result = test_db.get(Model, {"id": sample_model.id})
        assert result.status is True
        assert len(result.data) == 1
        assert result.data[0].config["model"] == "gpt-4-turbo"

    def test_delete_operations(self, test_db: DatabaseManager, sample_model: Model):
        """Test delete with various filters"""
        # Insert first; fail here rather than inside the delete assertions.
        assert test_db.upsert(sample_model).status is True

        # Test deletion by id
        response = test_db.delete(Model, {"id": sample_model.id})
        assert response.status is True
        assert "Deleted Successfully" in response.message

        # Verify deletion
        result = test_db.get(Model, {"id": sample_model.id})
        assert len(result.data) == 0

        # Test deletion with non-existent id
        response = test_db.delete(Model, {"id": 999999})
        assert "Row not found" in response.message

    def test_initialize_database_scenarios(self, tmp_path: Path):
        """Test different initialize_database parameters"""
        # Use pytest's per-test directory rather than a fixed file in the cwd.
        db_path = tmp_path / "test_init.db"
        db = DatabaseManager(f"sqlite:///{db_path}")

        try:
            # Test basic initialization
            response = db.initialize_database()
            assert response.status is True

            # Test with auto_upgrade
            response = db.initialize_database(auto_upgrade=True)
            assert response.status is True
        finally:
            # Always close the manager, even if reset_db() raises, and close
            # it last so connections are released before the file is removed.
            try:
                db.reset_db()
            finally:
                asyncio.run(db.close())
            db_path.unlink(missing_ok=True)