diff --git a/api/app.py b/api/app.py index ad91b5636f..124306b010 100644 --- a/api/app.py +++ b/api/app.py @@ -115,7 +115,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint == 'console': + if request.blueprint in ['console', 'inner_api']: # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get('Authorization', '') if not auth_header: @@ -153,6 +153,7 @@ def register_blueprints(app): from controllers.files import bp as files_bp from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp + from controllers.inner_api import bp as inner_api_bp CORS(service_api_bp, allow_headers=['Content-Type', 'Authorization', 'X-App-Code'], @@ -188,6 +189,8 @@ def register_blueprints(app): ) app.register_blueprint(files_bp) + app.register_blueprint(inner_api_bp) + # create app app = create_app() diff --git a/api/config.py b/api/config.py index f210ac48f9..631be4bbb5 100644 --- a/api/config.py +++ b/api/config.py @@ -69,6 +69,8 @@ DEFAULTS = { 'TOOL_ICON_CACHE_MAX_AGE': 3600, 'MILVUS_DATABASE': 'default', 'KEYWORD_DATA_SOURCE_TYPE': 'database', + 'INNER_API': 'False', + 'ENTERPRISE_ENABLED': 'False', } @@ -133,6 +135,11 @@ class Config: # Alternatively you can set it with `SECRET_KEY` environment variable. self.SECRET_KEY = get_env('SECRET_KEY') + # Enable or disable the inner API. + self.INNER_API = get_bool_env('INNER_API') + # The inner API key is used to authenticate the inner API. + self.INNER_API_KEY = get_env('INNER_API_KEY') + # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) @@ -327,6 +334,8 @@ class Config: self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE') self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE') + self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED') + class CloudEditionConfig(Config): diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 6cee7314e2..2895dbe73e 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -19,4 +19,6 @@ from .datasets import data_source, datasets, datasets_document, datasets_segment from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message, workflow) # Import workspace controllers -from .workspace import account, members, model_providers, models, tool_providers, workspace \ No newline at end of file +from .workspace import account, members, model_providers, models, tool_providers, workspace +# Import enterprise controllers +from .enterprise import enterprise_sso diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index d8cea95f48..8a24e58413 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -26,10 +26,13 @@ class LoginApi(Resource): try: account = AccountService.authenticate(args['email'], args['password']) - except services.errors.account.AccountLoginError: - return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 + except services.errors.account.AccountLoginError as e: + return {'code': 'unauthorized', 'message': str(e)}, 401 - TenantService.create_owner_tenant_if_not_exist(account) + # SELF_HOSTED only have one workspace + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} AccountService.update_last_login(account, request) diff --git a/api/controllers/console/enterprise/__init__.py b/api/controllers/console/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/console/enterprise/enterprise_sso.py b/api/controllers/console/enterprise/enterprise_sso.py new file mode 100644 index 0000000000..f6a2897d5a --- /dev/null +++ b/api/controllers/console/enterprise/enterprise_sso.py @@ -0,0 +1,59 @@ +from flask import current_app, redirect +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.setup import setup_required +from services.enterprise.enterprise_sso_service import EnterpriseSSOService + + +class EnterpriseSSOSamlLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_saml_login() + + +class EnterpriseSSOSamlAcs(Resource): + + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('SAMLResponse', type=str, required=True, location='form') + args = parser.parse_args() + saml_response = args['SAMLResponse'] + + try: + token = EnterpriseSSOService.post_sso_saml_acs(saml_response) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +class EnterpriseSSOOidcLogin(Resource): + + @setup_required + def get(self): + return EnterpriseSSOService.get_sso_oidc_login() + + +class EnterpriseSSOOidcCallback(Resource): + + @setup_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('state', type=str, required=True, location='args') + parser.add_argument('code', type=str, required=True, location='args') + parser.add_argument('oidc-state', type=str, required=True, location='cookies') + args = parser.parse_args() + + try: + token = EnterpriseSSOService.get_sso_oidc_callback(args) + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}') + except Exception as e: + return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}') + + +api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login') +api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs') +api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login') +api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback') diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 824549050f..325652a447 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,6 +1,7 @@ from flask_login import current_user from flask_restful import Resource +from services.enterprise.enterprise_feature_service import EnterpriseFeatureService from services.feature_service import FeatureService from . import api @@ -14,4 +15,10 @@ class FeatureApi(Resource): return FeatureService.get_features(current_user.current_tenant_id).dict() +class EnterpriseFeatureApi(Resource): + def get(self): + return EnterpriseFeatureService.get_enterprise_features().dict() + + api.add_resource(FeatureApi, '/features') +api.add_resource(EnterpriseFeatureApi, '/enterprise-features') diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index a8d0dd4344..1911559cff 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -58,6 +58,8 @@ class SetupApi(Resource): password=args['password'] ) + TenantService.create_owner_tenant_if_not_exist(account) + setup() AccountService.update_last_login(account, request) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7b3f08f467..cd72872b62 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -3,6 +3,7 @@ import logging from flask import request from flask_login import current_user from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from werkzeug.exceptions import Unauthorized import services from controllers.console import api @@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant +from models.account import Tenant, TenantStatus from services.account_service import TenantService from services.file_service import FileService from services.workspace_service import WorkspaceService @@ -116,6 +117,16 @@ class TenantApi(Resource): tenant = current_user.current_tenant + if tenant.status == TenantStatus.ARCHIVE: + tenants = TenantService.get_join_tenants(current_user) + # if there is any tenant, switch to the first one + if len(tenants) > 0: + TenantService.switch_tenant(current_user, tenants[0].id) + tenant = tenants[0] + # else, raise Unauthorized + else: + raise Unauthorized('workspace is archived') + return WorkspaceService.get_tenant_info(tenant), 200 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py new file mode 100644 index 0000000000..067c28c3fa --- /dev/null +++ b/api/controllers/inner_api/__init__.py @@ -0,0 +1,8 @@ +from flask import Blueprint +from libs.external_api import ExternalApi + +bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +api = ExternalApi(bp) + +from .workspace import workspace + diff --git a/api/controllers/inner_api/workspace/__init__.py b/api/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py new file mode 100644 index 0000000000..06610d8933 --- /dev/null +++ b/api/controllers/inner_api/workspace/workspace.py @@ -0,0 +1,37 @@ +from flask_restful import Resource, reqparse + +from controllers.console.setup import setup_required +from controllers.inner_api import api +from controllers.inner_api.wraps import inner_api_only +from events.tenant_event import tenant_was_created +from models.account import Account +from services.account_service import TenantService + + +class EnterpriseWorkspace(Resource): + + @setup_required + @inner_api_only + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('owner_email', type=str, required=True, location='json') + args = parser.parse_args() + + account = Account.query.filter_by(email=args['owner_email']).first() + if account is None: + return { + 'message': 'owner account not found.' + }, 404 + + tenant = TenantService.create_tenant(args['name']) + TenantService.create_tenant_member(tenant, account, role='owner') + + tenant_was_created.send(tenant) + + return { + 'message': 'enterprise workspace created.' + } + + +api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py new file mode 100644 index 0000000000..07cd38bc85 --- /dev/null +++ b/api/controllers/inner_api/wraps.py @@ -0,0 +1,61 @@ +from base64 import b64encode +from functools import wraps +from hashlib import sha1 +from hmac import new as hmac_new + +from flask import abort, current_app, request + +from extensions.ext_database import db +from models.model import EndUser + + +def inner_api_only(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + abort(404) + + # get header 'X-Inner-Api-Key' + inner_api_key = request.headers.get('X-Inner-Api-Key') + if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']: + abort(404) + + return view(*args, **kwargs) + + return decorated + + +def inner_api_user_auth(view): + @wraps(view) + def decorated(*args, **kwargs): + if not current_app.config['INNER_API']: + return view(*args, **kwargs) + + # get header 'X-Inner-Api-Key' + authorization = request.headers.get('Authorization') + if not authorization: + return view(*args, **kwargs) + + parts = authorization.split(':') + if len(parts) != 2: + return view(*args, **kwargs) + + user_id, token = parts + if ' ' in user_id: + user_id = user_id.split(' ')[1] + + inner_api_key = request.headers.get('X-Inner-Api-Key') + + data_to_sign = f'DIFY {user_id}' + + signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) + signature = b64encode(signature.digest()).decode('utf-8') + + if signature != token: + return view(*args, **kwargs) + + kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 70733d63f4..8ae81531ae 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user -from models.account import Account, Tenant, TenantAccountJoin +from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService @@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if not app_model.enable_api: raise NotFound() + tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant.status == TenantStatus.ARCHIVE: + raise NotFound() + kwargs['app_model'] = app_model if fetch_user_arg: @@ -137,6 +141,7 @@ def validate_dataset_token(view=None): .filter(Tenant.id == api_token.tenant_id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.role.in_(['owner'])) \ + .filter(Tenant.status == TenantStatus.NORMAL) \ .one_or_none() # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index bf3536d276..49b0a8bfc0 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource): if not site: raise Forbidden() + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) diff --git a/api/models/account.py b/api/models/account.py index 11aa1c996d..7854e3f63e 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -105,6 +105,12 @@ class Account(UserMixin, db.Model): def is_admin_or_owner(self): return self._current_tenant.current_role in ['admin', 'owner'] + +class TenantStatus(str, enum.Enum): + NORMAL = 'normal' + ARCHIVE = 'archive' + + class Tenant(db.Model): __tablename__ = 'tenants' __table_args__ = ( diff --git a/api/services/account_service.py b/api/services/account_service.py index 1fe8da760c..64fe3a4f0f 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -8,7 +8,7 @@ from typing import Any, Optional from flask import current_app from sqlalchemy import func -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import Unauthorized from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created @@ -44,7 +44,7 @@ class AccountService: return None if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: - raise Forbidden('Account is banned or closed.') + raise Unauthorized("Account is banned or closed.") current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: @@ -255,7 +255,7 @@ class TenantService: """Get account join tenants""" return db.session.query(Tenant).join( TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id).all() + ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() @staticmethod def get_current_tenant_by_account(account: Account): @@ -279,7 +279,12 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() + tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ).first() + if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: diff --git a/api/services/enterprise/__init__.py b/api/services/enterprise/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py new file mode 100644 index 0000000000..c483d28152 --- /dev/null +++ b/api/services/enterprise/base.py @@ -0,0 +1,20 @@ +import os + +import requests + + +class EnterpriseRequest: + base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') + secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + + @classmethod + def send_request(cls, method, endpoint, json=None, params=None): + headers = { + "Content-Type": "application/json", + "Enterprise-Api-Secret-Key": cls.secret_key + } + + url = f"{cls.base_url}{endpoint}" + response = requests.request(method, url, json=json, params=params, headers=headers) + + return response.json() diff --git a/api/services/enterprise/enterprise_feature_service.py b/api/services/enterprise/enterprise_feature_service.py new file mode 100644 index 0000000000..fe33349aa8 --- /dev/null +++ b/api/services/enterprise/enterprise_feature_service.py @@ -0,0 +1,28 @@ +from flask import current_app +from pydantic import BaseModel + +from services.enterprise.enterprise_service import EnterpriseService + + +class EnterpriseFeatureModel(BaseModel): + sso_enforced_for_signin: bool = False + sso_enforced_for_signin_protocol: str = '' + + +class EnterpriseFeatureService: + + @classmethod + def get_enterprise_features(cls) -> EnterpriseFeatureModel: + features = EnterpriseFeatureModel() + + if current_app.config['ENTERPRISE_ENABLED']: + cls._fulfill_params_from_enterprise(features) + + return features + + @classmethod + def _fulfill_params_from_enterprise(cls, features): + enterprise_info = EnterpriseService.get_info() + + features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] + features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py new file mode 100644 index 0000000000..115d0d5523 --- /dev/null +++ b/api/services/enterprise/enterprise_service.py @@ -0,0 +1,8 @@ +from services.enterprise.base import EnterpriseRequest + + +class EnterpriseService: + + @classmethod + def get_info(cls): + return EnterpriseRequest.send_request('GET', '/info') diff --git a/api/services/enterprise/enterprise_sso_service.py b/api/services/enterprise/enterprise_sso_service.py new file mode 100644 index 0000000000..d8e19f23bf --- /dev/null +++ b/api/services/enterprise/enterprise_sso_service.py @@ -0,0 +1,60 @@ +import logging + +from models.account import Account, AccountStatus +from services.account_service import AccountService, TenantService +from services.enterprise.base import EnterpriseRequest + +logger = logging.getLogger(__name__) + + +class EnterpriseSSOService: + + @classmethod + def get_sso_saml_login(cls) -> str: + return EnterpriseRequest.send_request('GET', '/sso/saml/login') + + @classmethod + def post_sso_saml_acs(cls, saml_response: str) -> str: + response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('Saml response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def get_sso_oidc_login(cls): + return EnterpriseRequest.send_request('GET', '/sso/oidc/login') + + @classmethod + def get_sso_oidc_callback(cls, args: dict): + state_from_query = args['state'] + code_from_query = args['code'] + state_from_cookies = args['oidc-state'] + + if state_from_cookies != state_from_query: + raise Exception('invalid state or code') + + response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query}) + if 'email' not in response or response['email'] is None: + logger.exception(response) + raise Exception('OIDC response is invalid') + + return cls.login_with_email(response.get('email')) + + @classmethod + def login_with_email(cls, email: str) -> str: + account = Account.query.filter_by(email=email).first() + if account is None: + raise Exception('account not found, please contact system admin to invite you to join in a workspace') + + if account.status == AccountStatus.BANNED: + raise Exception('account is banned, please contact system admin') + + tenants = TenantService.get_join_tenants(account) + if len(tenants) == 0: + raise Exception("workspace not found, please contact system admin to invite you to join in a workspace") + + token = AccountService.get_account_jwt_token(account) + + return token diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 720260a307..ba9f9f32c6 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) { url: '/logout', params: {}, }) + + if (localStorage?.getItem('console_token')) + localStorage.removeItem('console_token') + router.push('/signin') } diff --git a/web/app/signin/_header.tsx b/web/app/signin/_header.tsx index 7180a66817..a9479a3fe4 100644 --- a/web/app/signin/_header.tsx +++ b/web/app/signin/_header.tsx @@ -10,9 +10,6 @@ import LogoSite from '@/app/components/base/logo/logo-site' const Header = () => { const { locale, setLocaleOnClient } = useContext(I18n) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') - return