autogen/python/packages/autogen-studio/autogenstudio/database/schema_manager.py

540 lines
18 KiB
Python

import os
from pathlib import Path
import shutil
from typing import Optional, Tuple, List
from loguru import logger
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.autogenerate import compare_metadata
from sqlalchemy import Engine
from sqlmodel import SQLModel
from alembic.util.exc import CommandError
class SchemaManager:
    """
    Manages database schema validation and migrations using Alembic.

    Provides automatic schema validation, migrations, and safe upgrades.

    Args:
        engine: SQLAlchemy engine instance
        base_dir: Directory that holds ``alembic.ini`` and the ``alembic/``
            script directory (defaults to the directory of this module)
        auto_upgrade: Whether to automatically upgrade schema when differences found
        init_mode: Controls initialization behavior:
            - "none": No automatic initialization (raises error if not set up)
            - "auto": Initialize if not present (default)
            - "force": Always reinitialize, removing existing configuration
              (the ``versions`` directory is preserved)
    """

    # Status string check_schema_status reports when no action is needed;
    # ensure_schema_up_to_date uses it to tell "nothing to do" from failure.
    _UP_TO_DATE_STATUS = "Database schema is up to date"

    def __init__(
        self,
        engine: Engine,
        base_dir: Optional[Path] = None,
        auto_upgrade: bool = True,
        init_mode: str = "auto",
    ):
        if init_mode not in ("none", "auto", "force"):
            raise ValueError("init_mode must be one of: none, auto, force")

        self.engine = engine
        self.auto_upgrade = auto_upgrade

        # Use provided base_dir or default to this module's directory
        self.base_dir = base_dir or Path(__file__).parent
        self.alembic_dir = self.base_dir / "alembic"
        self.alembic_ini_path = self.base_dir / "alembic.ini"

        self.base_dir.mkdir(parents=True, exist_ok=True)

        if init_mode == "force":
            self._cleanup_existing_alembic()
            self._initialize_alembic()
            return

        try:
            self._validate_alembic_setup()
        except FileNotFoundError:
            if init_mode == "none":
                raise
            logger.info("Initializing new Alembic configuration")
            self._initialize_alembic()
        else:
            logger.info("Using existing Alembic configuration")
            self._update_configuration()

    def _update_configuration(self) -> None:
        """Rewrite alembic.ini and refresh env.py / script template for the current settings."""
        logger.info("Updating existing Alembic configuration...")

        self._write_alembic_ini()

        env_path = self.alembic_dir / "env.py"
        if env_path.exists():
            self._update_env_py(env_path)
        else:
            self._create_minimal_env_py(env_path)

        # An existing setup may lack script.py.mako, which `command.revision`
        # needs in order to render new migration files.
        self._ensure_script_template()

    def _cleanup_existing_alembic(self) -> None:
        """
        Remove the existing Alembic configuration, keeping the ``versions``
        directory so previously generated migrations survive.

        Removal is best effort: failures are logged and skipped.
        """
        logger.info(
            "Cleaning up existing Alembic configuration while preserving versions...")

        if self.alembic_dir.exists():
            if (self.alembic_dir / "versions").exists():
                logger.info("Preserving existing versions directory")
            # Snapshot the listing so we don't delete entries mid-iteration.
            for item in list(self.alembic_dir.iterdir()):
                if item.name == "versions":
                    continue
                try:
                    if item.is_dir():
                        shutil.rmtree(item)
                        logger.info(f"Removed directory: {item}")
                    else:
                        item.unlink()
                        logger.info(f"Removed file: {item}")
                except OSError as e:
                    logger.error(f"Failed to remove {item}: {e}")

        if self.alembic_ini_path.exists():
            try:
                self.alembic_ini_path.unlink()
                logger.info(
                    f"Removed existing alembic.ini: {self.alembic_ini_path}")
            except OSError as e:
                logger.error(f"Failed to remove alembic.ini: {e}")

    def _ensure_alembic_setup(self, *, force: bool = False) -> None:
        """
        Ensures Alembic is properly set up, initializing if necessary.

        Args:
            force: If True, removes existing configuration and reinitializes
        """
        try:
            self._validate_alembic_setup()
            if force:
                logger.info(
                    "Force initialization requested. Cleaning up existing configuration...")
                self._cleanup_existing_alembic()
                self._initialize_alembic()
        except FileNotFoundError:
            logger.info("Alembic configuration not found. Initializing...")
            if self.alembic_dir.exists():
                logger.warning(
                    "Found existing alembic directory but missing configuration")
                self._cleanup_existing_alembic()
            self._initialize_alembic()

        logger.info("Alembic initialization complete")

    def _initialize_alembic(self) -> None:
        """
        Create the Alembic layout: ``alembic/versions``, ``alembic/env.py``
        (if missing), ``alembic/script.py.mako`` (if missing) and ``alembic.ini``.

        ``alembic.command.init`` is intentionally not used: it refuses to run
        on a non-empty directory, and ``alembic/`` is never empty here (the
        ``versions`` directory is created first, or preserved by a forced
        re-init), so it could never contribute its script template.
        """
        logger.info("Initializing Alembic configuration...")

        (self.alembic_dir / "versions").mkdir(parents=True, exist_ok=True)

        env_path = self.alembic_dir / "env.py"
        if not env_path.exists():
            self._create_minimal_env_py(env_path)
            logger.info("Created new env.py")

        self._write_alembic_ini()
        logger.info("Created alembic.ini")

        # Required by `command.revision` to render migration files.
        self._ensure_script_template()
        logger.info("Initialized Alembic directory structure")

    def _ensure_script_template(self) -> None:
        """
        Ensure ``alembic/script.py.mako`` exists.

        Copies Alembic's bundled ``generic`` template and adds
        ``import sqlmodel`` so autogenerated migrations that render SQLModel
        column types (e.g. ``sqlmodel.sql.sqltypes.AutoString``) can import it.
        Does nothing if the template already exists.
        """
        template_path = self.alembic_dir / "script.py.mako"
        if template_path.exists():
            return

        source = Path(Config().get_template_directory()) / "generic" / "script.py.mako"
        if not source.exists():
            logger.warning(
                f"Alembic script template not found at {source}; "
                "revision generation will not work")
            return

        content = source.read_text(encoding="utf-8")
        if "import sqlmodel" not in content:
            content = content.replace(
                "import sqlalchemy as sa",
                "import sqlalchemy as sa\nimport sqlmodel",
                1,
            )
        template_path.write_text(content, encoding="utf-8")
        logger.info(f"Created script template: {template_path}")

    def _create_minimal_env_py(self, env_path: Path) -> None:
        """Creates a minimal env.py file for Alembic that targets SQLModel.metadata."""
        content = '''from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context
from sqlmodel import SQLModel

config = context.config

# disable_existing_loggers=False keeps the host application's loggers
# working when migrations are run in-process.
if config.config_file_name is not None:
    fileConfig(config.config_file_name, disable_existing_loggers=False)

target_metadata = SQLModel.metadata


def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
'''
        env_path.write_text(content, encoding="utf-8")

    def _config_safe_url(self, *, hide_password: bool) -> str:
        """
        Render the engine URL for use as a ConfigParser value.

        ``str(engine.url)`` masks the password under SQLAlchemy 2.0, so the URL
        is rendered explicitly. ``%`` is doubled because Alembic's ConfigParser
        performs ``%``-interpolation on option values.
        """
        return self.engine.url.render_as_string(
            hide_password=hide_password).replace("%", "%%")

    def _generate_alembic_ini_content(self) -> str:
        """
        Generates content for alembic.ini file.

        The persisted ``sqlalchemy.url`` has its password masked so credentials
        are not written to disk; ``get_alembic_config`` supplies the real URL
        in memory for programmatic commands.
        """
        script_location = str(self.alembic_dir).replace("%", "%%")
        url = self._config_safe_url(hide_password=True)
        return f"""
[alembic]
script_location = {script_location}
sqlalchemy.url = {url}

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()

    def _write_alembic_ini(self) -> None:
        """(Over)write alembic.ini with the current settings."""
        self.alembic_ini_path.write_text(
            self._generate_alembic_ini_content(), encoding="utf-8")

    def _update_env_py(self, env_path: Path) -> None:
        """
        Updates the env.py file to use SQLModel metadata.

        Adds the SQLModel import, swaps ``target_metadata = None`` for
        ``SQLModel.metadata`` and injects ``compare_type=True`` into
        ``context.configure(`` calls when it is not already present.
        Creates a minimal env.py instead if the file does not exist.
        """
        if not env_path.exists():
            self._create_minimal_env_py(env_path)
            return

        try:
            content = env_path.read_text(encoding="utf-8")

            if "from sqlmodel import SQLModel" not in content:
                content = "from sqlmodel import SQLModel\n" + content

            content = content.replace(
                "target_metadata = None",
                "target_metadata = SQLModel.metadata"
            )

            if "context.configure(" in content and "compare_type=True" not in content:
                content = content.replace(
                    "context.configure(",
                    "context.configure(compare_type=True,"
                )

            env_path.write_text(content, encoding="utf-8")
            logger.info("Updated env.py with SQLModel metadata")
        except Exception as e:
            logger.error(f"Failed to update env.py: {e}")
            raise

    def _validate_alembic_setup(self) -> None:
        """
        Validates that Alembic is properly configured.

        Raises:
            FileNotFoundError: If alembic.ini, env.py or the versions directory is missing
        """
        required_files = [
            self.alembic_ini_path,
            self.alembic_dir / "env.py",
            self.alembic_dir / "versions",
        ]
        missing = [f for f in required_files if not f.exists()]
        if missing:
            raise FileNotFoundError(
                f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}"
            )

    def get_alembic_config(self) -> Config:
        """
        Gets Alembic configuration.

        The ini file stores a password-masked URL, so the real URL is set on
        the returned (in-memory) Config; env.py reads it from there.

        Returns:
            Config: Alembic Config object

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        if not self.alembic_ini_path.exists():
            raise FileNotFoundError("Could not find alembic.ini")
        config = Config(str(self.alembic_ini_path))
        config.set_main_option(
            "sqlalchemy.url", self._config_safe_url(hide_password=False))
        return config

    def get_current_revision(self) -> Optional[str]:
        """
        Gets the current database revision.

        Returns:
            str: Current revision string or None if no revision
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            return context.get_current_revision()

    def get_head_revision(self) -> Optional[str]:
        """
        Gets the latest available revision.

        Returns:
            Optional[str]: Head revision string, or None if no migrations exist
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)
        return script.get_current_head()

    def get_schema_differences(self) -> List[tuple]:
        """
        Detects differences between current database and models.

        Returns:
            List[tuple]: List of differences found (as reported by compare_metadata)
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            diff = compare_metadata(context, SQLModel.metadata)
            return list(diff)

    def check_schema_status(self) -> Tuple[bool, str]:
        """
        Checks if database schema matches current models and migrations.

        Returns:
            Tuple[bool, str]: (needs_upgrade, status_message). Errors are
            reported as (True, "Error checking schema: ...").
        """
        try:
            current_rev = self.get_current_revision()
            head_rev = self.get_head_revision()

            if current_rev != head_rev:
                return True, f"Database needs upgrade: {current_rev} -> {head_rev}"

            differences = self.get_schema_differences()
            if differences:
                changes_desc = "\n".join(str(diff) for diff in differences)
                return True, f"Unmigrated changes detected:\n{changes_desc}"

            return False, self._UP_TO_DATE_STATUS
        except Exception as e:
            logger.error(f"Error checking schema status: {str(e)}")
            return True, f"Error checking schema: {str(e)}"

    def upgrade_schema(self, revision: str = "head") -> bool:
        """
        Upgrades database schema to specified revision.

        Args:
            revision: Target revision (default: "head")

        Returns:
            bool: True if upgrade successful, False if it raised (error is logged)
        """
        try:
            config = self.get_alembic_config()
            command.upgrade(config, revision)
            logger.info(f"Schema upgraded successfully to {revision}")
            return True
        except Exception as e:
            logger.error(f"Schema upgrade failed: {str(e)}")
            return False

    def check_and_upgrade(self) -> Tuple[bool, str]:
        """
        Checks schema status and upgrades if necessary (and auto_upgrade is True).

        Returns:
            Tuple[bool, str]: (action_taken, status_message)
        """
        needs_upgrade, status = self.check_schema_status()

        if needs_upgrade:
            if self.auto_upgrade:
                if self.upgrade_schema():
                    return True, "Schema was automatically upgraded"
                return False, "Automatic schema upgrade failed"
            return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"

        return False, status

    def generate_revision(self, message: str = "auto") -> Optional[str]:
        """
        Generates new migration revision for current schema changes.

        Args:
            message: Revision message

        Returns:
            str: Revision ID if successful, None otherwise (error is logged)
        """
        try:
            config = self.get_alembic_config()
            command.revision(
                config,
                message=message,
                autogenerate=True
            )
            return self.get_head_revision()
        except Exception as e:
            logger.error(f"Failed to generate revision: {str(e)}")
            return None

    def get_pending_migrations(self) -> List[str]:
        """
        Gets list of pending migrations that need to be applied.

        Returns:
            List[str]: Pending revision IDs in application order (oldest first)
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)

        current = self.get_current_revision()
        head = script.get_current_head()

        if head is None or current == head:
            return []

        # iterate_revisions(upper, lower) walks from `upper` down to, but
        # excluding, `lower` (None means base) -- the same call alembic's
        # upgrade planning makes. Reverse it to get application order.
        pending = [
            rev.revision
            for rev in script.iterate_revisions(head, current, implicit_base=True)
        ]
        pending.reverse()
        return pending

    def print_status(self) -> None:
        """Prints current migration status information to logger."""
        current = self.get_current_revision()
        head = self.get_head_revision()
        differences = self.get_schema_differences()
        pending = self.get_pending_migrations()

        logger.info("=== Database Schema Status ===")
        logger.info(f"Current revision: {current}")
        logger.info(f"Head revision: {head}")
        logger.info(f"Pending migrations: {len(pending)}")
        for rev in pending:
            logger.info(f" - {rev}")
        logger.info(f"Unmigrated changes: {len(differences)}")
        for diff in differences:
            logger.info(f" - {diff}")

    def ensure_schema_up_to_date(self) -> bool:
        """
        Ensures the database schema is up to date, generating and applying migrations if needed.

        Steps: apply existing pending migrations, autogenerate a revision for
        any remaining model differences, then apply it.

        Returns:
            bool: True if schema is up to date or was successfully updated
        """
        try:
            # Alembic refuses to autogenerate against a database that is behind
            # head ("Target database is not up to date"), and existing
            # migrations may already cover the model changes, so apply them first.
            if self.get_current_revision() != self.get_head_revision():
                upgraded, status = self.check_and_upgrade()
                if not upgraded:
                    logger.warning(status)
                    return False

            differences = self.get_schema_differences()
            if differences:
                revision = self.generate_revision("auto-generated")
                if not revision:
                    return False
                logger.info(f"Generated new migration: {revision}")

            upgraded, status = self.check_and_upgrade()
            # "Not upgraded" is only success when nothing needed doing; it also
            # covers a failed upgrade and auto_upgrade being disabled.
            if upgraded or status == self._UP_TO_DATE_STATUS:
                return True
            logger.warning(status)
            return False
        except Exception as e:
            logger.error(f"Failed to ensure schema is up to date: {e}")
            return False