mirror of https://github.com/microsoft/autogen.git
407 lines
16 KiB
Python
407 lines
16 KiB
Python
import threading
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from loguru import logger
|
|
from sqlalchemy import exc, func, inspect, text
|
|
from sqlmodel import Session, SQLModel, and_, create_engine, select
|
|
|
|
from ..datamodel import LinkTypes, Response
|
|
from .schema_manager import SchemaManager
|
|
|
|
# from .dbutils import init_db_samples
|
|
|
|
|
|
class DatabaseManager:
|
|
_init_lock = threading.Lock()
|
|
|
|
def __init__(self, engine_uri: str, base_dir: Optional[Path] = None):
|
|
"""
|
|
Initialize DatabaseManager with database connection settings.
|
|
Does not perform any database operations.
|
|
|
|
Args:
|
|
engine_uri: Database connection URI (e.g. sqlite:///db.sqlite3)
|
|
base_dir: Base directory for migration files. If None, uses current directory
|
|
"""
|
|
connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}
|
|
|
|
self.engine = create_engine(engine_uri, connect_args=connection_args)
|
|
self.schema_manager = SchemaManager(
|
|
engine=self.engine,
|
|
base_dir=base_dir,
|
|
)
|
|
|
|
def initialize_database(self, auto_upgrade: bool = False, force_init_alembic: bool = True) -> Response:
|
|
"""
|
|
Initialize database and migrations in the correct order.
|
|
|
|
Args:
|
|
auto_upgrade: If True, automatically generate and apply migrations for schema changes
|
|
force_init_alembic: If True, reinitialize alembic configuration even if it exists
|
|
"""
|
|
if not self._init_lock.acquire(blocking=False):
|
|
return Response(message="Database initialization already in progress", status=False)
|
|
|
|
try:
|
|
inspector = inspect(self.engine)
|
|
tables_exist = inspector.get_table_names()
|
|
|
|
if not tables_exist:
|
|
# Fresh install - create tables and initialize migrations
|
|
logger.info("Creating database tables...")
|
|
SQLModel.metadata.create_all(self.engine)
|
|
|
|
if self.schema_manager.initialize_migrations(force=force_init_alembic):
|
|
return Response(message="Database initialized successfully", status=True)
|
|
return Response(message="Failed to initialize migrations", status=False)
|
|
|
|
# Handle existing database
|
|
if auto_upgrade:
|
|
logger.info("Checking database schema...")
|
|
if self.schema_manager.ensure_schema_up_to_date(): # <-- Use this instead
|
|
return Response(message="Database schema is up to date", status=True)
|
|
return Response(message="Database upgrade failed", status=False)
|
|
|
|
return Response(message="Database is ready", status=True)
|
|
|
|
except Exception as e:
|
|
error_msg = f"Database initialization failed: {str(e)}"
|
|
logger.error(error_msg)
|
|
return Response(message=error_msg, status=False)
|
|
finally:
|
|
self._init_lock.release()
|
|
|
|
def reset_db(self, recreate_tables: bool = True):
|
|
"""
|
|
Reset the database by dropping all tables and optionally recreating them.
|
|
|
|
Args:
|
|
recreate_tables (bool): If True, recreates the tables after dropping them.
|
|
Set to False if you want to call create_db_and_tables() separately.
|
|
"""
|
|
if not self._init_lock.acquire(blocking=False):
|
|
logger.warning("Database reset already in progress")
|
|
return Response(message="Database reset already in progress", status=False, data=None)
|
|
|
|
try:
|
|
# Dispose existing connections
|
|
self.engine.dispose()
|
|
with Session(self.engine) as session:
|
|
try:
|
|
# Disable foreign key checks for SQLite
|
|
if "sqlite" in str(self.engine.url):
|
|
session.exec(text("PRAGMA foreign_keys=OFF"))
|
|
|
|
# Drop all tables
|
|
SQLModel.metadata.drop_all(self.engine)
|
|
logger.info("All tables dropped successfully")
|
|
|
|
# Re-enable foreign key checks for SQLite
|
|
if "sqlite" in str(self.engine.url):
|
|
session.exec(text("PRAGMA foreign_keys=ON"))
|
|
|
|
session.commit()
|
|
|
|
except Exception as e:
|
|
session.rollback()
|
|
raise e
|
|
finally:
|
|
session.close()
|
|
self._init_lock.release()
|
|
|
|
if recreate_tables:
|
|
logger.info("Recreating tables...")
|
|
self.initialize_database(auto_upgrade=False, force_init_alembic=True)
|
|
|
|
return Response(
|
|
message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
|
|
status=True,
|
|
data=None,
|
|
)
|
|
|
|
except Exception as e:
|
|
error_msg = f"Error while resetting database: {str(e)}"
|
|
logger.error(error_msg)
|
|
return Response(message=error_msg, status=False, data=None)
|
|
finally:
|
|
if self._init_lock.locked():
|
|
self._init_lock.release()
|
|
logger.info("Database reset lock released")
|
|
|
|
def upsert(self, model: SQLModel, return_json: bool = True):
|
|
"""Create or update an entity
|
|
|
|
Args:
|
|
model (SQLModel): The model instance to create or update
|
|
return_json (bool, optional): If True, returns the model as a dictionary.
|
|
If False, returns the SQLModel instance. Defaults to True.
|
|
|
|
Returns:
|
|
Response: Contains status, message and data (either dict or SQLModel based on return_json)
|
|
"""
|
|
status = True
|
|
model_class = type(model)
|
|
existing_model = None
|
|
|
|
with Session(self.engine) as session:
|
|
try:
|
|
existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
|
|
if existing_model:
|
|
model.updated_at = datetime.now()
|
|
for key, value in model.model_dump().items():
|
|
setattr(existing_model, key, value)
|
|
model = existing_model # Use the updated existing model
|
|
session.add(model)
|
|
else:
|
|
session.add(model)
|
|
session.commit()
|
|
session.refresh(model)
|
|
except Exception as e:
|
|
session.rollback()
|
|
logger.error("Error while updating/creating " + str(model_class.__name__) + ": " + str(e))
|
|
status = False
|
|
|
|
return Response(
|
|
message=(
|
|
f"{model_class.__name__} Updated Successfully"
|
|
if existing_model
|
|
else f"{model_class.__name__} Created Successfully"
|
|
),
|
|
status=status,
|
|
data=model.model_dump() if return_json else model,
|
|
)
|
|
|
|
def _model_to_dict(self, model_obj):
|
|
return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}
|
|
|
|
def get(
|
|
self,
|
|
model_class: SQLModel,
|
|
filters: dict = None,
|
|
return_json: bool = False,
|
|
order: str = "desc",
|
|
):
|
|
"""List entities"""
|
|
with Session(self.engine) as session:
|
|
result = []
|
|
status = True
|
|
status_message = ""
|
|
|
|
try:
|
|
statement = select(model_class)
|
|
if filters:
|
|
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
|
statement = statement.where(and_(*conditions))
|
|
|
|
if hasattr(model_class, "created_at") and order:
|
|
order_by_clause = getattr(model_class.created_at, order)() # Dynamically apply asc/desc
|
|
statement = statement.order_by(order_by_clause)
|
|
|
|
items = session.exec(statement).all()
|
|
result = [self._model_to_dict(item) if return_json else item for item in items]
|
|
status_message = f"{model_class.__name__} Retrieved Successfully"
|
|
except Exception as e:
|
|
session.rollback()
|
|
status = False
|
|
status_message = f"Error while fetching {model_class.__name__}"
|
|
logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))
|
|
|
|
return Response(message=status_message, status=status, data=result)
|
|
|
|
def delete(self, model_class: SQLModel, filters: dict = None):
|
|
"""Delete an entity"""
|
|
status_message = ""
|
|
status = True
|
|
|
|
with Session(self.engine) as session:
|
|
try:
|
|
statement = select(model_class)
|
|
if filters:
|
|
conditions = [getattr(model_class, col) == value for col, value in filters.items()]
|
|
statement = statement.where(and_(*conditions))
|
|
|
|
rows = session.exec(statement).all()
|
|
|
|
if rows:
|
|
for row in rows:
|
|
session.delete(row)
|
|
session.commit()
|
|
status_message = f"{model_class.__name__} Deleted Successfully"
|
|
else:
|
|
status_message = "Row not found"
|
|
logger.info(f"Row with filters {filters} not found")
|
|
|
|
except exc.IntegrityError as e:
|
|
session.rollback()
|
|
status = False
|
|
status_message = f"Integrity error: The {model_class.__name__} is linked to another entity and cannot be deleted. {e}"
|
|
# Log the specific integrity error
|
|
logger.error(status_message)
|
|
except Exception as e:
|
|
session.rollback()
|
|
status = False
|
|
status_message = f"Error while deleting: {e}"
|
|
logger.error(status_message)
|
|
|
|
return Response(message=status_message, status=status, data=None)
|
|
|
|
def link(
|
|
self,
|
|
link_type: LinkTypes,
|
|
primary_id: int,
|
|
secondary_id: int,
|
|
sequence: Optional[int] = None,
|
|
):
|
|
"""Link two entities with automatic sequence handling."""
|
|
with Session(self.engine) as session:
|
|
try:
|
|
# Get classes from LinkTypes
|
|
primary_class = link_type.primary_class
|
|
secondary_class = link_type.secondary_class
|
|
link_table = link_type.link_table
|
|
|
|
# Get entities
|
|
primary_entity = session.get(primary_class, primary_id)
|
|
secondary_entity = session.get(secondary_class, secondary_id)
|
|
|
|
if not primary_entity or not secondary_entity:
|
|
return Response(message="One or both entities do not exist", status=False)
|
|
|
|
# Get field names
|
|
primary_id_field = f"{primary_class.__name__.lower()}_id"
|
|
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
|
|
|
|
# Check for existing link
|
|
existing_link = session.exec(
|
|
select(link_table).where(
|
|
and_(
|
|
getattr(link_table, primary_id_field) == primary_id,
|
|
getattr(link_table, secondary_id_field) == secondary_id,
|
|
)
|
|
)
|
|
).first()
|
|
|
|
if existing_link:
|
|
return Response(message="Link already exists", status=False)
|
|
|
|
# Get the next sequence number if not provided
|
|
if sequence is None:
|
|
max_seq_result = session.exec(
|
|
select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id)
|
|
).first()
|
|
sequence = 0 if max_seq_result is None else max_seq_result + 1
|
|
|
|
# Create new link
|
|
new_link = link_table(
|
|
**{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence}
|
|
)
|
|
session.add(new_link)
|
|
session.commit()
|
|
|
|
return Response(message=f"Entities linked successfully with sequence {sequence}", status=True)
|
|
|
|
except Exception as e:
|
|
session.rollback()
|
|
return Response(message=f"Error linking entities: {str(e)}", status=False)
|
|
|
|
def unlink(self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None):
|
|
"""Unlink two entities and reorder sequences if needed."""
|
|
with Session(self.engine) as session:
|
|
try:
|
|
# Get classes from LinkTypes
|
|
primary_class = link_type.primary_class
|
|
secondary_class = link_type.secondary_class
|
|
link_table = link_type.link_table
|
|
|
|
# Get field names
|
|
primary_id_field = f"{primary_class.__name__.lower()}_id"
|
|
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
|
|
|
|
# Find existing link
|
|
statement = select(link_table).where(
|
|
and_(
|
|
getattr(link_table, primary_id_field) == primary_id,
|
|
getattr(link_table, secondary_id_field) == secondary_id,
|
|
)
|
|
)
|
|
|
|
if sequence is not None:
|
|
statement = statement.where(link_table.sequence == sequence)
|
|
|
|
existing_link = session.exec(statement).first()
|
|
|
|
if not existing_link:
|
|
return Response(message="Link does not exist", status=False)
|
|
|
|
deleted_sequence = existing_link.sequence
|
|
session.delete(existing_link)
|
|
|
|
# Reorder sequences for remaining links
|
|
remaining_links = session.exec(
|
|
select(link_table)
|
|
.where(getattr(link_table, primary_id_field) == primary_id)
|
|
.where(link_table.sequence > deleted_sequence)
|
|
.order_by(link_table.sequence)
|
|
).all()
|
|
|
|
# Decrease sequence numbers to fill the gap
|
|
for link in remaining_links:
|
|
link.sequence -= 1
|
|
|
|
session.commit()
|
|
|
|
return Response(message="Entities unlinked successfully and sequences reordered", status=True)
|
|
|
|
except Exception as e:
|
|
session.rollback()
|
|
return Response(message=f"Error unlinking entities: {str(e)}", status=False)
|
|
|
|
def get_linked_entities(
|
|
self,
|
|
link_type: LinkTypes,
|
|
primary_id: int,
|
|
return_json: bool = False,
|
|
):
|
|
"""Get linked entities based on link type and primary ID, ordered by sequence."""
|
|
with Session(self.engine) as session:
|
|
try:
|
|
# Get classes from LinkTypes
|
|
primary_class = link_type.primary_class
|
|
secondary_class = link_type.secondary_class
|
|
link_table = link_type.link_table
|
|
|
|
# Get field names
|
|
primary_id_field = f"{primary_class.__name__.lower()}_id"
|
|
secondary_id_field = f"{secondary_class.__name__.lower()}_id"
|
|
|
|
# Query both link and entity, ordered by sequence
|
|
items = session.exec(
|
|
select(secondary_class)
|
|
.join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id)
|
|
.where(getattr(link_table, primary_id_field) == primary_id)
|
|
.order_by(link_table.sequence)
|
|
).all()
|
|
|
|
result = [item.model_dump() if return_json else item for item in items]
|
|
|
|
return Response(message="Linked entities retrieved successfully", status=True, data=result)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting linked entities: {str(e)}")
|
|
return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[])
|
|
|
|
# Add new close method
|
|
|
|
async def close(self):
|
|
"""Close database connections and cleanup resources"""
|
|
logger.info("Closing database connections...")
|
|
try:
|
|
# Dispose of the SQLAlchemy engine
|
|
self.engine.dispose()
|
|
logger.info("Database connections closed successfully")
|
|
except Exception as e:
|
|
logger.error(f"Error closing database connections: {str(e)}")
|
|
raise
|