# autogen/python/packages/autogen-studio/autogenstudio/web/auth/manager.py
#
# 151 lines
# 6.1 KiB
# Python

import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt
import yaml
from fastapi import Request
from loguru import logger
from typing_extensions import Self

from .exceptions import ConfigurationException, InvalidTokenException, MissingTokenException
from .models import AuthConfig, User
from .providers import AuthProvider, FirebaseAuthProvider, GithubAuthProvider, MSALAuthProvider, NoAuthProvider
class AuthManager:
    """
    Manages authentication for the application.

    Handles token creation, validation, and provider selection.
    """

    # Scheme prefix expected at the start of the ``Authorization`` header.
    _BEARER_PREFIX = "Bearer "

    def __init__(self, config: AuthConfig):
        """Initialize the auth manager with configuration.

        Args:
            config: Authentication settings. ``config.type`` selects the
                provider that is created and stored on ``self.provider``.
        """
        self.config = config
        self.provider = self._create_provider()
        logger.info(f"Initialized auth manager with provider: {config.type}")

    @staticmethod
    def _guest_user() -> User:
        """Return the default user used whenever authentication is skipped."""
        return User(id="guestuser@gmail.com", name="Default User", provider="none")

    def _create_provider(self) -> AuthProvider:
        """Create the appropriate auth provider based on config.

        Returns:
            The provider matching ``config.type`` ("github", "msal" or
            "firebase"); ``NoAuthProvider`` for any other type, or when
            constructing the selected provider raises.
        """
        try:
            auth_type = self.config.type
            if auth_type == "github":
                return GithubAuthProvider(self.config)
            if auth_type == "msal":
                return MSALAuthProvider(self.config)
            if auth_type == "firebase":
                return FirebaseAuthProvider(self.config)
            return NoAuthProvider()
        except Exception as e:
            logger.error(f"Failed to create auth provider: {str(e)}")
            # Deliberate best-effort: fall back to no auth if provider creation fails.
            return NoAuthProvider()

    def create_token(self, user: User) -> str:
        """Create a JWT token for an authenticated user.

        Args:
            user: The user whose id, name, email, provider and roles are
                embedded in the token payload.

        Returns:
            An HS256-signed JWT, or the placeholder ``"dummy_token_<id>"``
            when no JWT secret is configured.
        """
        if not self.config.jwt_secret:
            logger.warning("JWT secret not configured, using insecure token")
            return "dummy_token_" + user.id

        # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
        expiry = datetime.now(timezone.utc) + timedelta(minutes=self.config.token_expiry_minutes)
        payload = {
            "sub": user.id,
            "name": user.name,
            "email": user.email,
            "provider": user.provider,
            "roles": user.roles,
            "exp": expiry,
        }
        return jwt.encode(payload, self.config.jwt_secret, algorithm="HS256")

    async def authenticate_request(self, request: Request) -> User:
        """Authenticate a request and return user information.

        Args:
            request: Incoming request; its URL path and ``Authorization``
                header are inspected.

        Returns:
            The default guest user when the path is excluded, when the auth
            type is "none", or when no JWT secret is configured; otherwise a
            ``User`` built from the decoded token payload.

        Raises:
            MissingTokenException: If the ``Authorization`` header is absent
                or does not start with ``"Bearer "``.
            InvalidTokenException: If the token is expired or otherwise
                fails JWT validation.
        """
        # Excluded paths and no-auth mode both bypass token handling entirely.
        if request.url.path in self.config.exclude_paths:
            return self._guest_user()
        if self.config.type == "none":
            return self._guest_user()

        # Extract token from Authorization header.
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith(self._BEARER_PREFIX):
            raise MissingTokenException()
        # Strip only the leading scheme prefix; str.replace would also remove
        # any later occurrence of "Bearer " inside the header value.
        token = auth_header[len(self._BEARER_PREFIX):]

        if not self.config.jwt_secret:
            # NOTE(security): development mode - without a JWT secret every
            # bearer token is accepted and mapped to the guest user.
            logger.warning("JWT secret not configured, accepting all tokens")
            return self._guest_user()

        try:
            # Decode and validate JWT.
            payload = jwt.decode(token, self.config.jwt_secret, algorithms=["HS256"])
        except jwt.ExpiredSignatureError as e:
            logger.warning(f"Expired token received: {token[:10]}...")
            # Chain to the real cause so the original error stays in the traceback.
            raise InvalidTokenException() from e
        except jwt.InvalidTokenError as e:
            logger.warning(f"Invalid token received: {token[:10]}...")
            raise InvalidTokenException() from e

        # Create User object from token payload.
        return User(
            id=payload.get("sub"),
            name=payload.get("name", "Unknown User"),
            email=payload.get("email"),
            provider=payload.get("provider", "jwt"),
            roles=payload.get("roles", ["user"]),
        )

    def is_valid_token(self, token: str) -> bool:
        """Check if a JWT token is valid.

        Args:
            token: The encoded JWT to check.

        Returns:
            ``True`` if the token decodes successfully with the configured
            secret, or if no secret is configured; ``False`` if it is
            expired, malformed, or fails validation.
        """
        if not self.config.jwt_secret:
            return True  # No validation in dev mode

        try:
            jwt.decode(token, self.config.jwt_secret, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            logger.warning("Token has expired")
            return False
        except jwt.InvalidTokenError as e:
            # Bad signature, malformed token, etc.: a boolean predicate
            # reports these as "not valid" instead of raising.
            logger.warning(f"Token is invalid: {str(e)}")
            return False
        return True

    @classmethod
    def from_yaml(cls, yaml_path: str) -> Self:
        """Create AuthManager from YAML config file.

        Args:
            yaml_path: Path to a YAML file whose top level is a mapping of
                ``AuthConfig`` fields.

        Returns:
            A new manager built from the parsed configuration.

        Raises:
            ConfigurationException: If the file cannot be read or parsed, or
                the configuration or manager cannot be constructed from it.
        """
        try:
            with open(yaml_path, "r", encoding="utf-8") as f:
                config_data = yaml.safe_load(f)
            # An empty file parses to None and a list/scalar document is not
            # a mapping; report that clearly instead of an opaque TypeError.
            if not isinstance(config_data, dict):
                raise ValueError("auth config file must contain a YAML mapping at the top level")
            config = AuthConfig(**config_data)
            return cls(config)
        except Exception as e:
            logger.error(f"Failed to load auth config from {yaml_path}: {str(e)}")
            raise ConfigurationException(f"Failed to load auth config: {str(e)}") from e

    @classmethod
    def from_env(cls) -> Self:
        """Create AuthManager from environment variables.

        Reads ``AUTOGENSTUDIO_AUTH_TYPE`` (default "none"),
        ``AUTOGENSTUDIO_JWT_SECRET`` and ``AUTOGENSTUDIO_TOKEN_EXPIRY``
        (minutes, default 60). When the auth type is "github", the
        ``AUTOGENSTUDIO_GITHUB_*`` variables are read as well.

        Returns:
            A new manager built from the environment configuration.

        Raises:
            ValueError: If ``AUTOGENSTUDIO_TOKEN_EXPIRY`` is not an integer.
        """
        auth_type = os.environ.get("AUTOGENSTUDIO_AUTH_TYPE", "none")
        config_dict: Dict[str, Any] = {
            "type": auth_type,
            "jwt_secret": os.environ.get("AUTOGENSTUDIO_JWT_SECRET"),
            "token_expiry_minutes": int(os.environ.get("AUTOGENSTUDIO_TOKEN_EXPIRY", "60")),
        }

        # Add provider-specific config based on the auth type.
        if auth_type == "github":
            raw_scopes = os.environ.get("AUTOGENSTUDIO_GITHUB_SCOPES", "user:email")
            config_dict["github"] = {
                "client_id": os.environ.get("AUTOGENSTUDIO_GITHUB_CLIENT_ID", ""),
                "client_secret": os.environ.get("AUTOGENSTUDIO_GITHUB_CLIENT_SECRET", ""),
                "callback_url": os.environ.get("AUTOGENSTUDIO_GITHUB_CALLBACK_URL", ""),
                # Tolerate "a, b" style lists: trim whitespace, drop empty entries.
                "scopes": [scope.strip() for scope in raw_scopes.split(",") if scope.strip()],
            }
        # Add other provider config parsing here.

        config = AuthConfig(**config_dict)
        return cls(config)