mirror of https://github.com/microsoft/autogen.git
540 lines
18 KiB
Python
540 lines
18 KiB
Python
import os
|
|
from pathlib import Path
|
|
import shutil
|
|
from typing import Optional, Tuple, List
|
|
from loguru import logger
|
|
from alembic import command
|
|
from alembic.config import Config
|
|
from alembic.runtime.migration import MigrationContext
|
|
from alembic.script import ScriptDirectory
|
|
from alembic.autogenerate import compare_metadata
|
|
from sqlalchemy import Engine
|
|
from sqlmodel import SQLModel
|
|
from alembic.util.exc import CommandError
|
|
|
|
|
|
class SchemaManager:
    """
    Manages database schema validation and migrations using Alembic.
    Provides automatic schema validation, migrations, and safe upgrades.

    Args:
        engine: SQLAlchemy engine instance
        base_dir: Directory holding ``alembic.ini`` and the ``alembic/`` folder.
            Defaults to the directory containing this module.
        auto_upgrade: Whether to automatically upgrade schema when differences found
        init_mode: Controls initialization behavior:
            - "none": No automatic initialization (raises error if not set up)
            - "auto": Initialize if not present (default)
            - "force": Always reinitialize, removing existing configuration
              (the ``versions`` directory is preserved)

    Raises:
        ValueError: If ``init_mode`` is not one of the supported values.
        FileNotFoundError: If ``init_mode`` is "none" and the setup is incomplete.
    """

    _VALID_INIT_MODES = ("none", "auto", "force")

    # Revision template used by ``command.revision``. It mirrors Alembic's
    # generic template and additionally imports ``sqlmodel`` so that
    # autogenerated column types referring to it can be resolved.
    _SCRIPT_TEMPLATE = '''"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
'''

    def __init__(
        self,
        engine: Engine,
        base_dir: Optional[Path] = None,
        auto_upgrade: bool = True,
        init_mode: str = "auto"
    ):
        if init_mode not in self._VALID_INIT_MODES:
            raise ValueError("init_mode must be one of: none, auto, force")

        self.engine = engine
        self.auto_upgrade = auto_upgrade

        # Use provided base_dir or default to class file location.
        # Path() also accepts a plain string, so str callers keep working.
        self.base_dir = Path(base_dir) if base_dir else Path(__file__).parent
        self.alembic_dir = self.base_dir / 'alembic'
        self.alembic_ini_path = self.base_dir / 'alembic.ini'

        # Create base directory if it doesn't exist
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Initialize based on mode
        if init_mode == "force":
            self._cleanup_existing_alembic()
            self._initialize_alembic()
            return

        try:
            self._validate_alembic_setup()
        except FileNotFoundError:
            if init_mode == "none":
                raise
            logger.info("Initializing new Alembic configuration")
            self._initialize_alembic()
        else:
            logger.info("Using existing Alembic configuration")
            # Update existing configuration
            self._update_configuration()

    def _write_alembic_ini(self) -> None:
        """Writes the generated configuration to ``alembic.ini``."""
        # Deliberately uses the platform default encoding: the file is read
        # back by Alembic through ConfigParser, not by this class.
        with open(self.alembic_ini_path, 'w') as f:
            f.write(self._generate_alembic_ini_content())

    def _ensure_script_template(self) -> None:
        """Creates ``script.py.mako`` in the Alembic directory when missing.

        ``command.revision`` renders new migration files from this template.
        An existing template is never overwritten.
        """
        template_path = self.alembic_dir / 'script.py.mako'
        if template_path.exists():
            return
        template_path.write_text(self._SCRIPT_TEMPLATE, encoding="utf-8")
        logger.info("Created script.py.mako")

    def _update_configuration(self) -> None:
        """Updates existing Alembic configuration with current settings."""
        logger.info("Updating existing Alembic configuration...")

        # Update alembic.ini
        self._write_alembic_ini()

        # Update env.py (created from scratch when it does not exist)
        self._update_env_py(self.alembic_dir / 'env.py')

        # Setups created before the template was managed here may lack it
        self._ensure_script_template()

    def _cleanup_existing_alembic(self) -> None:
        """
        Safely removes existing Alembic configuration while preserving versions directory.

        Removal is best-effort: a failure on one entry is logged and the
        remaining entries are still processed.
        """
        logger.info(
            "Cleaning up existing Alembic configuration while preserving versions...")

        if self.alembic_dir.exists():
            if (self.alembic_dir / 'versions').exists():
                logger.info("Preserving existing versions directory")

            # Remove alembic directory contents EXCEPT versions
            for item in self.alembic_dir.iterdir():
                if item.name == 'versions':
                    continue
                try:
                    # A symlink to a directory must be unlinked; rmtree refuses it
                    if item.is_dir() and not item.is_symlink():
                        shutil.rmtree(item)
                        logger.info(f"Removed directory: {item}")
                    else:
                        item.unlink()
                        logger.info(f"Removed file: {item}")
                except Exception as e:
                    logger.error(f"Failed to remove {item}: {e}")

        # Remove alembic.ini if it exists
        if self.alembic_ini_path.exists():
            try:
                self.alembic_ini_path.unlink()
                logger.info(
                    f"Removed existing alembic.ini: {self.alembic_ini_path}")
            except Exception as e:
                logger.error(f"Failed to remove alembic.ini: {e}")

    def _ensure_alembic_setup(self, *, force: bool = False) -> None:
        """
        Ensures Alembic is properly set up, initializing if necessary.

        Args:
            force: If True, removes existing configuration and reinitializes
        """
        try:
            self._validate_alembic_setup()
        except FileNotFoundError:
            logger.info("Alembic configuration not found. Initializing...")
            if self.alembic_dir.exists():
                logger.warning(
                    "Found existing alembic directory but missing configuration")
                self._cleanup_existing_alembic()
            self._initialize_alembic()
            logger.info("Alembic initialization complete")
        else:
            if force:
                logger.info(
                    "Force initialization requested. Cleaning up existing configuration...")
                self._cleanup_existing_alembic()
                self._initialize_alembic()

    def _initialize_alembic(self) -> None:
        """Creates the Alembic directory layout, env.py, template and alembic.ini."""
        logger.info("Initializing Alembic configuration...")

        # Create directories first
        self.alembic_dir.mkdir(parents=True, exist_ok=True)
        (self.alembic_dir / 'versions').mkdir(exist_ok=True)

        # Create env.py BEFORE running command.init
        env_path = self.alembic_dir / 'env.py'
        if not env_path.exists():
            self._create_minimal_env_py(env_path)
            logger.info("Created new env.py")

        # command.init below is expected to refuse the (now non-empty)
        # directory, so the revision template has to be provided here.
        self._ensure_script_template()

        # Write alembic.ini
        self._write_alembic_ini()
        logger.info("Created alembic.ini")

        # Now run alembic init
        try:
            config = self.get_alembic_config()
            command.init(config, str(self.alembic_dir))
            logger.info("Initialized Alembic directory structure")
        except CommandError as e:
            # An already-populated directory is the normal case; anything
            # else is a real failure.
            if "already exists" not in str(e):
                raise

    def _create_minimal_env_py(self, env_path: Path) -> None:
        """Creates a minimal env.py file for Alembic.

        Args:
            env_path: Destination path of the env.py file (overwritten if present)
        """
        content = '''
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = SQLModel.metadata

def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True
    )
    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True
        )
        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()'''

        with open(env_path, 'w', encoding="utf-8") as f:
            f.write(content)

    def _render_database_url(self) -> str:
        """Returns the engine URL in a form that is safe to embed in alembic.ini.

        The password is rendered in clear text, because the generated env.py
        builds its own engine from this value and a masked password cannot
        authenticate. Percent signs are doubled so that ConfigParser
        interpolation does not choke on URL-encoded characters.
        """
        url = self.engine.url
        render = getattr(url, "render_as_string", None)
        # Fall back to str() for URL-like objects without render_as_string
        rendered = render(hide_password=False) if callable(render) else str(url)
        return rendered.replace("%", "%%")

    def _generate_alembic_ini_content(self) -> str:
        """
        Generates content for alembic.ini file.

        Returns:
            str: Complete ini text, without leading or trailing whitespace
        """
        script_location = str(self.alembic_dir).replace("%", "%%")
        database_url = self._render_database_url()
        return f"""
[alembic]
script_location = {script_location}
sqlalchemy.url = {database_url}

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()

    def _update_env_py(self, env_path: Path) -> None:
        """
        Updates the env.py file to use SQLModel metadata.

        A missing file is created from the minimal template instead.

        Args:
            env_path: Path of the env.py file to patch

        Raises:
            Exception: Any error raised while reading or writing the file is
                logged and re-raised.
        """
        if not env_path.exists():
            self._create_minimal_env_py(env_path)
            return
        try:
            with open(env_path, 'r', encoding="utf-8") as f:
                content = f.read()

            # Add SQLModel import
            if "from sqlmodel import SQLModel" not in content:
                content = "from sqlmodel import SQLModel\n" + content

            # Replace target_metadata
            content = content.replace(
                "target_metadata = None",
                "target_metadata = SQLModel.metadata"
            )

            # Add compare_type=True to context.configure
            if "context.configure(" in content and "compare_type=True" not in content:
                content = content.replace(
                    "context.configure(",
                    "context.configure(compare_type=True,"
                )

            with open(env_path, 'w', encoding="utf-8") as f:
                f.write(content)

            logger.info("Updated env.py with SQLModel metadata")
        except Exception as e:
            logger.error(f"Failed to update env.py: {e}")
            raise

    def _validate_alembic_setup(self) -> None:
        """Validates that Alembic is properly configured.

        Raises:
            FileNotFoundError: If alembic.ini, env.py or the versions
                directory is missing.
        """
        required_files = [
            self.alembic_ini_path,
            self.alembic_dir / 'env.py',
            self.alembic_dir / 'versions'
        ]

        missing = [f for f in required_files if not f.exists()]
        if missing:
            raise FileNotFoundError(
                f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}"
            )

    def get_alembic_config(self) -> Config:
        """
        Gets Alembic configuration.

        Returns:
            Config: Alembic Config object

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        if not self.alembic_ini_path.exists():
            raise FileNotFoundError("Could not find alembic.ini")

        return Config(str(self.alembic_ini_path))

    def get_current_revision(self) -> Optional[str]:
        """
        Gets the current database revision.

        Returns:
            str: Current revision string or None if no revision
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            return context.get_current_revision()

    def get_head_revision(self) -> Optional[str]:
        """
        Gets the latest available revision.

        Returns:
            str: Head revision string, or None when no revision scripts exist
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)
        return script.get_current_head()

    def get_schema_differences(self) -> List[tuple]:
        """
        Detects differences between current database and models.

        Returns:
            List[tuple]: List of differences found
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            diff = compare_metadata(context, SQLModel.metadata)
            return list(diff)

    def check_schema_status(self) -> Tuple[bool, str]:
        """
        Checks if database schema matches current models and migrations.

        Any error raised while checking is logged and reported as
        "needs upgrade" together with the error text.

        Returns:
            Tuple[bool, str]: (needs_upgrade, status_message)
        """
        try:
            current_rev = self.get_current_revision()
            head_rev = self.get_head_revision()

            if current_rev != head_rev:
                return True, f"Database needs upgrade: {current_rev} -> {head_rev}"

            differences = self.get_schema_differences()
            if differences:
                changes_desc = "\n".join(str(diff) for diff in differences)
                return True, f"Unmigrated changes detected:\n{changes_desc}"

            return False, "Database schema is up to date"

        except Exception as e:
            logger.error(f"Error checking schema status: {str(e)}")
            return True, f"Error checking schema: {str(e)}"

    def upgrade_schema(self, revision: str = "head") -> bool:
        """
        Upgrades database schema to specified revision.

        Args:
            revision: Target revision (default: "head")

        Returns:
            bool: True if upgrade successful, False if it raised (the error is logged)
        """
        try:
            config = self.get_alembic_config()
            command.upgrade(config, revision)
            logger.info(f"Schema upgraded successfully to {revision}")
            return True

        except Exception as e:
            logger.error(f"Schema upgrade failed: {str(e)}")
            return False

    def check_and_upgrade(self) -> Tuple[bool, str]:
        """
        Checks schema status and upgrades if necessary (and auto_upgrade is True).

        Returns:
            Tuple[bool, str]: (action_taken, status_message)
        """
        needs_upgrade, status = self.check_schema_status()

        if not needs_upgrade:
            return False, status

        if not self.auto_upgrade:
            return False, f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}"

        if self.upgrade_schema():
            return True, "Schema was automatically upgraded"
        return False, "Automatic schema upgrade failed"

    def generate_revision(self, message: str = "auto") -> Optional[str]:
        """
        Generates new migration revision for current schema changes.

        Args:
            message: Revision message

        Returns:
            str: Revision ID if successful, None otherwise
        """
        try:
            config = self.get_alembic_config()
            command.revision(
                config,
                message=message,
                autogenerate=True
            )
            return self.get_head_revision()

        except Exception as e:
            logger.error(f"Failed to generate revision: {str(e)}")
            return None

    def get_pending_migrations(self) -> List[str]:
        """
        Gets list of pending migrations that need to be applied.

        Returns:
            List[str]: Pending migration revision IDs, oldest first (the
                order in which an upgrade applies them)
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)

        current = self.get_current_revision()
        head = self.get_head_revision()

        if current == head:
            return []

        # iterate_revisions(upper, lower) walks DOWN from upper and stops
        # before lower, so head goes first; a current of None means "base".
        pending = [rev.revision for rev in script.iterate_revisions(head, current)]
        pending.reverse()
        return pending

    def print_status(self) -> None:
        """Prints current migration status information to logger."""
        current = self.get_current_revision()
        head = self.get_head_revision()
        differences = self.get_schema_differences()
        pending = self.get_pending_migrations()

        logger.info("=== Database Schema Status ===")
        logger.info(f"Current revision: {current}")
        logger.info(f"Head revision: {head}")
        logger.info(f"Pending migrations: {len(pending)}")
        for rev in pending:
            logger.info(f"  - {rev}")
        logger.info(f"Unmigrated changes: {len(differences)}")
        for diff in differences:
            logger.info(f"  - {diff}")

    def ensure_schema_up_to_date(self) -> bool:
        """
        Ensures the database schema is up to date, generating and applying migrations if needed.

        Returns:
            bool: True if schema is up to date or was successfully updated.
                False if a migration could not be generated, the upgrade
                failed, auto_upgrade is disabled while an upgrade is
                pending, or an unexpected error occurred.
        """
        try:
            # Check for unmigrated changes
            differences = self.get_schema_differences()
            if differences:
                # Generate new migration
                revision = self.generate_revision("auto-generated")
                if not revision:
                    return False
                logger.info(f"Generated new migration: {revision}")

            # Apply any pending migrations. The outcome is taken from the
            # boolean results rather than from status text, so a failed
            # upgrade can never be mistaken for success.
            needs_upgrade, status = self.check_schema_status()
            if not needs_upgrade:
                return True

            if not self.auto_upgrade:
                logger.warning(
                    f"Schema needs upgrade but auto_upgrade is disabled. Status: {status}")
                return False

            return self.upgrade_schema()

        except Exception as e:
            logger.error(f"Failed to ensure schema is up to date: {e}")
            return False
|