feat: add web app auth

This commit is contained in:
GareArc 2025-04-07 17:03:26 -04:00
parent 1045f6db7a
commit 46d43e6758
7 changed files with 311 additions and 32 deletions

View File

@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
class WebSSOAuthRequiredError(BaseHTTPException):
error_code = "web_sso_auth_required"
description = "Web SSO authentication required."
class WebAppAuthRequiredError(BaseHTTPException):
error_code = "web_auth_required"
description = "Web app authentication required."
code = 401
class WebAppAuthFailedError(BaseHTTPException):
error_code = "web_app_auth_failed"
description = "You do not have permission to access this web app."
code = 401

View File

@ -0,0 +1,118 @@
from typing import cast
import flask_login
from flask import request
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError # type: ignore
from web import api
import services
from controllers.console.auth.error import (EmailCodeError,
EmailOrPasswordMismatchError,
InvalidEmailError)
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import setup_required
from libs.helper import email
from libs.password import valid_password
from models.account import Account
from services.account_service import AccountService
from services.webapp_auth_service import Unauthorized, WebAppAuthService
class LoginApi(Resource):
"""Resource for web app email/password login."""
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
try:
account = WebAppAuthService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
raise AccountNotFound()
token = WebAppAuthService.login(account=account, app_code=app_code)
return {"result": "success", "token": token}
class LogoutApi(Resource):
@setup_required
def get(self):
account = cast(Account, flask_login.current_user)
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
flask_login.logout_user()
return {"result": "success"}
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = WebAppAuthService.get_user_through_email(args["email"])
if account is None:
raise AccountNotFound()
else:
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
app_code = request.headers.get("X-App-Code")
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
if not account:
raise AccountNotFound()
token = WebAppAuthService.login(account=account, app_code=app_code)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "token": token}
api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

View File

@ -5,7 +5,7 @@ from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api
from controllers.web.error import WebSSOAuthRequiredError
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
@ -22,10 +22,10 @@ class PassportResource(Resource):
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.get_web_app_settings(app_code=app_code)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()

View File

@ -4,7 +4,8 @@ from flask import request
from flask_restful import Resource # type: ignore
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError
from controllers.web.error import (WebAppAuthFailedError,
WebAppAuthRequiredError)
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
@ -57,35 +58,48 @@ def decode_jwt_token():
if not end_user:
raise NotFound()
_validate_web_sso_token(decoded, system_features, app_code)
# for enterprise webapp auth
app_web_auth_enabled = False
if system_features.webapp_auth.enabled:
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)
return app_model, end_user
except Unauthorized as e:
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
raise WebSSOAuthRequiredError()
if system_features.webapp_auth.enabled:
app_web_auth_enabled = EnterpriseService.get_web_app_settings(app_code=app_code).get("access_mode", "private") == "private"
if app_web_auth_enabled:
raise WebAppAuthRequiredError()
raise Unauthorized(e.description)
def _validate_web_sso_token(decoded, system_features, app_code):
app_web_sso_enabled = False
# Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
if system_features.sso_enforced_for_web:
app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False)
if app_web_sso_enabled:
source = decoded.get("token_source")
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
# Check if authentication is enforced for web app, and if the token source is not webapp, raise an error and redirect to login
if system_webapp_auth_enabled and app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "sso":
raise Unauthorized("sso token expired.")
if not source or source != "webapp":
raise WebAppAuthRequiredError()
# Check if authentication is not enforced for web, and if the token source is webapp,
# raise an error and redirect to normal passport login
if not system_webapp_auth_enabled or not app_web_auth_enabled:
source = decoded.get("token_source")
if source and source == "webapp":
raise Unauthorized("webapp token expired.")
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
if system_webapp_auth_enabled and app_web_auth_enabled:
# Check if the user is allowed to access the web app
user_id = decoded.get("user_id")
if not user_id:
raise WebAppAuthRequiredError()
if not EnterpriseService.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
raise WebAppAuthFailedError()
class WebApiResource(Resource):

View File

@ -1,11 +1,35 @@
from pydantic import BaseModel, Field
from services.enterprise.base import EnterpriseRequest
class WebAppSettings(BaseModel):
access_mode: str = Field(
description="Access mode for the web app. Can be 'public' or 'private'",
default="private",
alias="access_mode",
)
class EnterpriseService:
@classmethod
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")
@classmethod
def get_app_web_sso_enabled(cls, app_code):
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id=None, app_code=None) -> bool:
if not app_id and not app_code:
raise ValueError("Either app_id or app_code must be provided.")
return EnterpriseRequest.send_request(
"GET", f"/web-app/allowed?appId={app_id}&appCode={app_code}&userId={user_id}"
)
@classmethod
def get_web_app_settings(cls, app_code: str = None, app_id: str = None):
if not app_code and not app_id:
raise ValueError("Either app_code or app_id must be provided.")
data = EnterpriseRequest.send_request("GET", f"/web-app/settings?appCode={app_code}&appId={app_id}")
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)

View File

@ -44,6 +44,13 @@ class BrandingModel(BaseModel):
favicon: str = ""
class WebAppAuthModel(BaseModel):
enabled: bool = False
allow_sso: bool = False
allow_email_code_login: bool = False
allow_email_password_login: bool = False
class FeatureModel(BaseModel):
billing: BillingModel = BillingModel()
members: LimitationModel = LimitationModel(size=0, limit=1)
@ -75,6 +82,7 @@ class SystemFeatureModel(BaseModel):
is_email_setup: bool = False
license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
class FeatureService:
@ -101,6 +109,7 @@ class FeatureService:
if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
cls._fulfill_params_from_enterprise(system_features)
return system_features
@ -194,6 +203,11 @@ class FeatureService:
features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "")
features.branding.favicon = enterprise_info["Branding"].get("favicon", "")
if "WebAppAuth" in enterprise_info:
features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSSO", False)
features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get("allowEmailCodeLogin", False)
features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get("allowEmailPasswordLogin", False)
if "License" in enterprise_info:
license_info = enterprise_info["License"]

View File

@ -0,0 +1,103 @@
import random
from datetime import UTC, datetime, timedelta
from typing import Any, Optional, cast
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from extensions.ext_database import db
from libs.helper import TokenManager
from libs.passport import PassportService
from libs.password import compare_password
from models.account import Account, AccountStatus
from models.model import Site
from services.errors.account import (AccountLoginError, AccountNotFoundError,
AccountPasswordError)
from tasks.mail_email_code_login import send_email_code_login_mail_task
class WebAppAuthService:
"""Service for web app authentication."""
@staticmethod
def authenticate(email: str, password: str) -> Account:
"""authenticate account with email and password"""
account = Account.query.filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
return cast(Account, account)
@staticmethod
def login(account: Account, app_code: str) -> str:
site = db.session.query(Site).filter(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
access_token = WebAppAuthService._get_account_jwt_token(account=account, site=site)
return access_token
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.")
return account
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
email = account.email if account else email
if email is None:
raise ValueError("Email must be provided.")
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
to=account.email if account else email,
code=code,
)
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "webapp_email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "webapp_email_code_login")
@staticmethod
def _get_account_jwt_token(account: Account, site: Site) -> str:
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.WebAppSessionTimeoutInHours * 24)
exp = int(exp_dt.timestamp())
payload = {
"iss": site.id,
"sub": "Web API Passport",
"app_id": site.app_id,
"app_code": site.code,
"user_id": account.id,
"token_source": "webapp",
"exp": exp,
}
token: str = PassportService().issue(payload)
return token