427 lines
14 KiB
Python
427 lines
14 KiB
Python
import logging
|
|
from typing import TYPE_CHECKING, Annotated, Literal
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
|
from pydantic import BaseModel, Field, SecretStr
|
|
|
|
from backend.data.graph import set_node_webhook
|
|
from backend.data.integrations import (
|
|
WebhookEvent,
|
|
get_all_webhooks_by_creds,
|
|
get_webhook,
|
|
publish_webhook_event,
|
|
wait_for_webhook_event,
|
|
)
|
|
from backend.data.model import (
|
|
APIKeyCredentials,
|
|
Credentials,
|
|
CredentialsType,
|
|
OAuth2Credentials,
|
|
)
|
|
from backend.executor.manager import ExecutionManager
|
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
from backend.integrations.oauth import HANDLERS_BY_NAME
|
|
from backend.integrations.providers import ProviderName
|
|
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
|
|
from backend.util.exceptions import NeedConfirmation
|
|
from backend.util.service import get_service_client
|
|
from backend.util.settings import Settings
|
|
|
|
if TYPE_CHECKING:
|
|
from backend.integrations.oauth import BaseOAuthHandler
|
|
|
|
from ..utils import get_user_id
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Application settings; `settings.config` and `settings.secrets` are read below.
settings = Settings()
# Router on which every endpoint in this module is registered.
router = APIRouter()

# Credentials manager shared by all endpoints in this module.
creds_manager = IntegrationCredentialsManager()
|
|
|
|
|
|
class LoginResponse(BaseModel):
    """Response body returned by the `login` endpoint."""

    # URL produced by the provider's OAuth handler for starting the login flow
    login_url: str
    # State token stored for this login attempt; `callback` expects it back
    state_token: str
|
|
|
|
|
|
@router.get("/{provider}/login")
def login(
    provider: Annotated[
        ProviderName, Path(title="The provider to initiate an OAuth flow for")
    ],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
    scopes: Annotated[
        str, Query(title="Comma-separated list of authorization scopes")
    ] = "",
) -> LoginResponse:
    """
    Start an OAuth login flow for `provider`.

    Stores a state token (together with the requested scopes) for the user and
    returns the provider login URL built by the provider's OAuth handler.

    Params:
        provider: The provider to initiate the OAuth flow for.
        user_id: ID of the authenticated user (resolved by `get_user_id`).
        request: The incoming request, used to resolve the OAuth handler.
        scopes: Comma-separated list of scopes to request; may be empty.

    Returns:
        The login URL and the state token that must be passed to `callback`.

    Raises:
        HTTPException: Propagated from `_get_provider_oauth_handler` if the
            provider does not support OAuth or is not configured.
    """
    handler = _get_provider_oauth_handler(request, provider)

    # Tolerate whitespace and stray commas ("a, b", "a,,b", "a,"): a plain
    # `split(",")` would pass on padded or empty scope names.
    requested_scopes = [
        scope for raw_scope in scopes.split(",") if (scope := raw_scope.strip())
    ]

    # Generate and store a secure random state token along with the scopes
    state_token, code_challenge = creds_manager.store.store_state_token(
        user_id, provider, requested_scopes
    )
    login_url = handler.get_login_url(
        requested_scopes, state_token, code_challenge=code_challenge
    )

    return LoginResponse(login_url=login_url, state_token=state_token)
|
|
|
|
|
|
class CredentialsMetaResponse(BaseModel):
    """Metadata describing a stored credentials object."""

    # Identifier of the credentials object
    id: str
    # Provider the credentials belong to
    provider: str
    # Kind of credentials
    type: CredentialsType
    # Title of the credentials, if one is set
    title: str | None
    # The listing endpoints in this module fill the next two fields only for
    # OAuth2 credentials and set them to `None` otherwise.
    scopes: list[str] | None
    username: str | None
|
|
|
|
|
|
@router.post("/{provider}/callback")
def callback(
    provider: Annotated[
        ProviderName, Path(title="The target provider for this OAuth exchange")
    ],
    code: Annotated[str, Body(title="Authorization code acquired by user login")],
    state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
    user_id: Annotated[str, Depends(get_user_id)],
    request: Request,
) -> CredentialsMetaResponse:
    """
    Complete an OAuth flow by exchanging the authorization code for credentials.

    Verifies `state_token` against the credentials store, exchanges `code`
    through the provider's OAuth handler, stores the resulting credentials for
    the user and returns their metadata.

    Params:
        provider: The target provider for this OAuth exchange.
        code: Authorization code acquired by user login.
        state_token: The state token that was issued by `login`.
        user_id: ID of the authenticated user (resolved by `get_user_id`).
        request: The incoming request, used to resolve the OAuth handler.

    Raises:
        HTTPException(400): If state token verification fails, or if the
            code-for-token exchange raises.
        HTTPException: Propagated from `_get_provider_oauth_handler` if the
            provider does not support OAuth or is not configured.
    """
    logger.debug(f"Received OAuth callback for provider: {provider}")
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    valid_state = creds_manager.store.verify_state_token(user_id, state_token, provider)

    if not valid_state:
        logger.warning(f"Invalid or expired state token for user {user_id}")
        raise HTTPException(status_code=400, detail="Invalid or expired state token")

    try:
        scopes = valid_state.scopes
        logger.debug(f"Retrieved scopes from state token: {scopes}")

        scopes = handler.handle_default_scopes(scopes)

        credentials = handler.exchange_code_for_tokens(
            code, scopes, valid_state.code_verifier
        )

        logger.debug(f"Received credentials with final scopes: {credentials.scopes}")

        # Check if the granted scopes are sufficient for the requested scopes
        if not set(scopes).issubset(set(credentials.scopes)):
            # For now, we'll just log the warning and continue
            logger.warning(
                f"Granted scopes {credentials.scopes} for provider {provider.value} "
                f"do not include all requested scopes {scopes}"
            )

    except Exception as e:
        # `logger.exception` records the traceback; `from e` keeps the original
        # error attached as the cause of the HTTP error.
        logger.exception(
            f"Code->Token exchange failed for provider {provider.value}: {e}"
        )
        raise HTTPException(
            status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
        ) from e

    # TODO: Allow specifying `title` to set on `credentials`
    creds_manager.create(user_id, credentials)

    logger.debug(
        f"Successfully processed OAuth callback for user {user_id} "
        f"and provider {provider.value}"
    )
    return CredentialsMetaResponse(
        id=credentials.id,
        provider=credentials.provider,
        type=credentials.type,
        title=credentials.title,
        scopes=credentials.scopes,
        username=credentials.username,
    )
|
|
|
|
|
|
@router.get("/credentials")
def list_credentials(
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
    """
    List metadata for all credentials stored for the authenticated user.

    `scopes` and `username` are only filled in for OAuth2 credentials; for any
    other credentials type they are `None`.
    """
    listing: list[CredentialsMetaResponse] = []
    for cred in creds_manager.store.get_all_creds(user_id):
        is_oauth2 = isinstance(cred, OAuth2Credentials)
        listing.append(
            CredentialsMetaResponse(
                id=cred.id,
                provider=cred.provider,
                type=cred.type,
                title=cred.title,
                scopes=cred.scopes if is_oauth2 else None,
                username=cred.username if is_oauth2 else None,
            )
        )
    return listing
|
|
|
|
|
|
@router.get("/{provider}/credentials")
def list_credentials_by_provider(
    provider: Annotated[
        ProviderName, Path(title="The provider to list credentials for")
    ],
    user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
    """
    List metadata for the authenticated user's credentials for one provider.

    `scopes` and `username` are only filled in for OAuth2 credentials; for any
    other credentials type they are `None`.
    """
    provider_creds = creds_manager.store.get_creds_by_provider(user_id, provider)

    entries: list[CredentialsMetaResponse] = []
    for cred in provider_creds:
        if isinstance(cred, OAuth2Credentials):
            granted_scopes, account_name = cred.scopes, cred.username
        else:
            granted_scopes, account_name = None, None
        entries.append(
            CredentialsMetaResponse(
                id=cred.id,
                provider=cred.provider,
                type=cred.type,
                title=cred.title,
                scopes=granted_scopes,
                username=account_name,
            )
        )
    return entries
|
|
|
|
|
|
@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
    provider: Annotated[
        ProviderName, Path(title="The provider to retrieve credentials for")
    ],
    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
    user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
    """
    Fetch a single credentials object of the authenticated user.

    Raises:
        HTTPException(404): If no credentials with `cred_id` are found for the
            user, or if they belong to a different provider than `provider`.
    """
    found = creds_manager.get(user_id, cred_id)
    if not found:
        raise HTTPException(status_code=404, detail="Credentials not found")

    if found.provider == provider:
        return found

    # The credentials exist but were requested under the wrong provider path.
    raise HTTPException(
        status_code=404, detail="Credentials do not match the specified provider"
    )
|
|
|
|
|
|
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
    user_id: Annotated[str, Depends(get_user_id)],
    provider: Annotated[
        ProviderName, Path(title="The provider to create credentials for")
    ],
    api_key: Annotated[str, Body(title="The API key to store")],
    title: Annotated[str, Body(title="Optional title for the credentials")],
    expires_at: Annotated[
        int | None, Body(title="Unix timestamp when the key expires")
    ] = None,
) -> APIKeyCredentials:
    """
    Store a new API key for `provider` as credentials of the authenticated user.

    Params:
        user_id: ID of the authenticated user (resolved by `get_user_id`).
        provider: The provider to create credentials for.
        api_key: The API key to store; it is wrapped in a `SecretStr`.
        title: Title for the new credentials.
        expires_at: Unix timestamp when the key expires, if any.

    Returns:
        The newly created credentials object.

    Raises:
        HTTPException(500): If storing the credentials fails.
    """
    new_credentials = APIKeyCredentials(
        provider=provider,
        api_key=SecretStr(api_key),
        title=title,
        expires_at=expires_at,
    )

    try:
        creds_manager.create(user_id, new_credentials)
    except Exception as e:
        # Log with traceback and chain the cause, so the underlying storage
        # error is not lost behind the generic 500 response.
        logger.exception(
            f"Failed to store API key credentials for user {user_id}: {e}"
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to store credentials: {str(e)}"
        ) from e
    return new_credentials
|
|
|
|
|
|
class CredentialsDeletionResponse(BaseModel):
    """Response body for a completed credentials deletion."""

    # Always `True` for this response type
    deleted: Literal[True] = True
    # No default is given, so callers must pass `revoked` explicitly
    revoked: bool | None = Field(
        description=(
            "Indicates whether the credentials were also revoked by "
            "their provider. `None`/`null` if not applicable, e.g. when "
            "deleting non-revocable credentials such as API keys."
        )
    )
|
|
|
|
|
|
class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
    """Response body for a deletion that was not carried out pending confirmation."""

    # Always `False` for this response type: nothing was deleted
    deleted: Literal[False] = False
    # Always `True` for this response type
    need_confirmation: Literal[True] = True
    # Explanation of why confirmation is needed
    message: str
|
|
|
|
|
|
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
    request: Request,
    provider: Annotated[
        ProviderName, Path(title="The provider to delete credentials for")
    ],
    cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
    user_id: Annotated[str, Depends(get_user_id)],
    force: Annotated[
        bool, Query(title="Whether to proceed if any linked webhooks are still in use")
    ] = False,
) -> CredentialsDeletionResponse | CredentialsDeletionNeedsConfirmationResponse:
    """
    Delete one of the authenticated user's credentials objects.

    Webhooks registered with the credentials are removed first. For OAuth2
    credentials, the provider's OAuth handler is asked to revoke the tokens
    after the credentials have been deleted.

    Returns:
        A `CredentialsDeletionNeedsConfirmationResponse` if webhook removal
        raised `NeedConfirmation`; otherwise a `CredentialsDeletionResponse`
        whose `revoked` is the handler's result for OAuth2 credentials and
        `None` for every other credentials type.

    Raises:
        HTTPException(404): If the credentials are not found for the user, or
            if they belong to a different provider than `provider`.
    """
    stored = creds_manager.store.get_creds_by_id(user_id, cred_id)
    if not stored:
        raise HTTPException(status_code=404, detail="Credentials not found")
    if stored.provider != provider:
        raise HTTPException(
            status_code=404, detail="Credentials do not match the specified provider"
        )

    try:
        await remove_all_webhooks_for_credentials(stored, force)
    except NeedConfirmation as e:
        return CredentialsDeletionNeedsConfirmationResponse(message=str(e))

    creds_manager.delete(user_id, cred_id)

    # Only OAuth2 credentials go through token revocation
    if not isinstance(stored, OAuth2Credentials):
        return CredentialsDeletionResponse(revoked=None)

    # NOTE(review): this runs after the credentials were deleted, so an error
    # raised while resolving the handler or revoking reaches the client even
    # though the deletion already happened - confirm this is intended.
    handler = _get_provider_oauth_handler(request, provider)
    was_revoked = handler.revoke_tokens(stored)

    return CredentialsDeletionResponse(revoked=was_revoked)
|
|
|
|
|
|
# ------------------------- WEBHOOK STUFF -------------------------- #
|
|
|
|
|
|
# ⚠️ Note
|
|
# No user auth check because this endpoint is for webhook ingress and relies on
|
|
# validation by the provider-specific `WebhooksManager`.
|
|
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
    request: Request,
    provider: Annotated[
        ProviderName, Path(title="Provider where the webhook was registered")
    ],
    webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
    """
    Receive an incoming webhook request from a provider.

    Validates the payload through the provider's webhooks manager, publishes it
    as a `WebhookEvent`, and adds an execution for every attached node that is
    triggered by the event type.

    Params:
        request: The incoming request, handed to the manager for validation.
        provider: Provider where the webhook was registered.
        webhook_id: Our ID for the webhook.

    Raises:
        HTTPException(404): If there is no webhooks manager for `provider`.
    """
    logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")

    # This endpoint is unauthenticated and `provider` comes from the URL, so a
    # provider without a webhooks manager must yield a clean 404 rather than a
    # KeyError (HTTP 500).
    if provider not in WEBHOOK_MANAGERS_BY_NAME:
        raise HTTPException(
            status_code=404,
            detail=f"Provider '{provider.value}' does not support webhooks",
        )
    webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

    webhook = await get_webhook(webhook_id)
    logger.debug(f"Webhook #{webhook_id}: {webhook}")
    payload, event_type = await webhook_manager.validate_payload(webhook, request)
    logger.debug(
        f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
        f"with payload {payload}"
    )

    webhook_event = WebhookEvent(
        provider=provider,
        webhook_id=webhook_id,
        event_type=event_type,
        payload=payload,
    )
    await publish_webhook_event(webhook_event)
    logger.debug(f"Webhook event published: {webhook_event}")

    if not webhook.attached_nodes:
        return

    executor = get_service_client(ExecutionManager)
    for node in webhook.attached_nodes:
        logger.debug(f"Webhook-attached node: {node}")
        if not node.is_triggered_by_event_type(event_type):
            logger.debug(f"Node #{node.id} doesn't trigger on event {event_type}")
            continue
        logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
        executor.add_execution(
            graph_id=node.graph_id,
            graph_version=node.graph_version,
            data={f"webhook_{webhook_id}_payload": payload},
            user_id=webhook.user_id,
        )
|
|
|
|
|
|
@router.post("/webhooks/{webhook_id}/ping")
async def webhook_ping(
    webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
    user_id: Annotated[str, Depends(get_user_id)],  # require auth
):
    """
    Trigger a ping for a webhook and wait for the ping event to arrive.

    Params:
        webhook_id: Our ID for the webhook.
        user_id: ID of the authenticated user (resolved by `get_user_id`).

    Returns:
        `True` if a "ping" event for the webhook was received in time, `False`
        if the provider's webhooks manager does not implement pinging.

    Raises:
        HTTPException(404): If the webhook does not belong to the authenticated
            user, or if there is no webhooks manager for its provider.
        HTTPException(504): If no ping event arrived within the timeout.
    """
    webhook = await get_webhook(webhook_id)

    # Authentication alone is not enough: only the owner may ping a webhook.
    # Answer 404 so the existence of other users' webhooks is not revealed.
    if webhook.user_id != user_id:
        raise HTTPException(status_code=404, detail="Webhook not found")

    # Avoid a KeyError (HTTP 500) for providers without a webhooks manager.
    if webhook.provider not in WEBHOOK_MANAGERS_BY_NAME:
        raise HTTPException(
            status_code=404,
            detail="No webhooks manager available for this webhook's provider",
        )
    webhook_manager = WEBHOOK_MANAGERS_BY_NAME[webhook.provider]()

    credentials = (
        creds_manager.get(user_id, webhook.credentials_id)
        if webhook.credentials_id
        else None
    )
    try:
        await webhook_manager.trigger_ping(webhook, credentials)
    except NotImplementedError:
        return False

    if not await wait_for_webhook_event(webhook_id, event_type="ping", timeout=10):
        raise HTTPException(status_code=504, detail="Webhook ping timed out")

    return True
|
|
|
|
|
|
# --------------------------- UTILITIES ---------------------------- #
|
|
|
|
|
|
async def remove_all_webhooks_for_credentials(
    credentials: Credentials, force: bool = False
) -> None:
    """
    Unlink and prune every webhook that was registered with the given credentials.

    If there is no webhooks manager for the credentials' provider, nothing is
    removed; an error is logged when webhooks are attached nonetheless.

    Params:
        credentials: The credentials whose associated webhooks should be removed.
        force: Whether to proceed even if some of the webhooks are still in use.

    Raises:
        NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
    """
    linked_webhooks = await get_all_webhooks_by_creds(credentials.id)

    if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
        if linked_webhooks:
            logger.error(
                f"Credentials #{credentials.id} for provider {credentials.provider} "
                f"are attached to {len(linked_webhooks)} webhooks, "
                f"but there is no available WebhooksHandler for {credentials.provider}"
            )
        return

    still_in_use = any(hook.attached_nodes for hook in linked_webhooks)
    if still_in_use and not force:
        raise NeedConfirmation(
            "Some webhooks linked to these credentials are still in use by an agent"
        )

    for hook in linked_webhooks:
        # Detach every node that is still linked to this webhook
        for attached_node in hook.attached_nodes or []:
            await set_node_webhook(attached_node.id, None)

        # Ask the provider's manager to prune the webhook
        manager = WEBHOOK_MANAGERS_BY_NAME[credentials.provider]()
        pruned = await manager.prune_webhook_if_dangling(hook.id, credentials)
        if not pruned:
            logger.warning(f"Webhook #{hook.id} failed to prune")
|
|
|
|
|
|
def _get_provider_oauth_handler(
    req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
    """
    Build the OAuth handler for a provider from the configured client secrets.

    Params:
        req: The incoming request; its base URL is the last-resort fallback
            for building the redirect URI.
        provider_name: The provider to get an OAuth handler for.

    Returns:
        A handler instance with client ID, client secret and redirect URI set.

    Raises:
        HTTPException(404): If the provider has no registered OAuth handler.
        HTTPException(501): If client ID or client secret is not configured.
    """
    if provider_name not in HANDLERS_BY_NAME:
        raise HTTPException(
            status_code=404,
            detail=f"Provider '{provider_name.value}' does not support OAuth",
        )

    # Default to `None` so that a missing secrets attribute is reported as
    # "not configured" (501) instead of surfacing as an AttributeError (500).
    client_id = getattr(settings.secrets, f"{provider_name.value}_client_id", None)
    client_secret = getattr(
        settings.secrets, f"{provider_name.value}_client_secret", None
    )
    if not (client_id and client_secret):
        raise HTTPException(
            status_code=501,
            detail=(
                f"Integration with provider '{provider_name.value}' is not configured"
            ),
        )

    handler_class = HANDLERS_BY_NAME[provider_name]
    frontend_base_url = (
        settings.config.frontend_base_url
        or settings.config.platform_base_url
        or str(req.base_url)
    )
    # `req.base_url` ends with "/" (and a configured URL may too); strip it so
    # the redirect URI does not contain a double slash.
    frontend_base_url = frontend_base_url.rstrip("/")
    return handler_class(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
    )
|