mirror of https://github.com/langgenius/dify.git
Feat/enterprise sso (#3602)
This commit is contained in:
parent
d9f1a8ce9f
commit
4481906be2
|
@ -115,7 +115,7 @@ def initialize_extensions(app):
|
||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user_from_request(request_from_flask_login):
|
def load_user_from_request(request_from_flask_login):
|
||||||
"""Load user based on the request."""
|
"""Load user based on the request."""
|
||||||
if request.blueprint == 'console':
|
if request.blueprint in ['console', 'inner_api']:
|
||||||
# Check if the user_id contains a dot, indicating the old format
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
auth_header = request.headers.get('Authorization', '')
|
auth_header = request.headers.get('Authorization', '')
|
||||||
if not auth_header:
|
if not auth_header:
|
||||||
|
@ -153,6 +153,7 @@ def register_blueprints(app):
|
||||||
from controllers.files import bp as files_bp
|
from controllers.files import bp as files_bp
|
||||||
from controllers.service_api import bp as service_api_bp
|
from controllers.service_api import bp as service_api_bp
|
||||||
from controllers.web import bp as web_bp
|
from controllers.web import bp as web_bp
|
||||||
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
|
|
||||||
CORS(service_api_bp,
|
CORS(service_api_bp,
|
||||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||||
|
@ -188,6 +189,8 @@ def register_blueprints(app):
|
||||||
)
|
)
|
||||||
app.register_blueprint(files_bp)
|
app.register_blueprint(files_bp)
|
||||||
|
|
||||||
|
app.register_blueprint(inner_api_bp)
|
||||||
|
|
||||||
|
|
||||||
# create app
|
# create app
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
|
@ -69,6 +69,8 @@ DEFAULTS = {
|
||||||
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
||||||
'MILVUS_DATABASE': 'default',
|
'MILVUS_DATABASE': 'default',
|
||||||
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
||||||
|
'INNER_API': 'False',
|
||||||
|
'ENTERPRISE_ENABLED': 'False',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,6 +135,11 @@ class Config:
|
||||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||||
|
|
||||||
|
# Enable or disable the inner API.
|
||||||
|
self.INNER_API = get_bool_env('INNER_API')
|
||||||
|
# The inner API key is used to authenticate the inner API.
|
||||||
|
self.INNER_API_KEY = get_env('INNER_API_KEY')
|
||||||
|
|
||||||
# cors settings
|
# cors settings
|
||||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||||
|
@ -327,6 +334,8 @@ class Config:
|
||||||
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
|
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
|
||||||
|
|
||||||
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
|
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
|
||||||
|
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
|
||||||
|
|
||||||
|
|
||||||
class CloudEditionConfig(Config):
|
class CloudEditionConfig(Config):
|
||||||
|
|
||||||
|
|
|
@ -19,4 +19,6 @@ from .datasets import data_source, datasets, datasets_document, datasets_segment
|
||||||
from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
|
from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
|
||||||
saved_message, workflow)
|
saved_message, workflow)
|
||||||
# Import workspace controllers
|
# Import workspace controllers
|
||||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||||
|
# Import enterprise controllers
|
||||||
|
from .enterprise import enterprise_sso
|
||||||
|
|
|
@ -26,10 +26,13 @@ class LoginApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
account = AccountService.authenticate(args['email'], args['password'])
|
account = AccountService.authenticate(args['email'], args['password'])
|
||||||
except services.errors.account.AccountLoginError:
|
except services.errors.account.AccountLoginError as e:
|
||||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
return {'code': 'unauthorized', 'message': str(e)}, 401
|
||||||
|
|
||||||
TenantService.create_owner_tenant_if_not_exist(account)
|
# SELF_HOSTED only have one workspace
|
||||||
|
tenants = TenantService.get_join_tenants(account)
|
||||||
|
if len(tenants) == 0:
|
||||||
|
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
|
||||||
|
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
from flask import current_app, redirect
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from services.enterprise.enterprise_sso_service import EnterpriseSSOService
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseSSOSamlLogin(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def get(self):
|
||||||
|
return EnterpriseSSOService.get_sso_saml_login()
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseSSOSamlAcs(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('SAMLResponse', type=str, required=True, location='form')
|
||||||
|
args = parser.parse_args()
|
||||||
|
saml_response = args['SAMLResponse']
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
|
||||||
|
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||||
|
except Exception as e:
|
||||||
|
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseSSOOidcLogin(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def get(self):
|
||||||
|
return EnterpriseSSOService.get_sso_oidc_login()
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseSSOOidcCallback(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def get(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('state', type=str, required=True, location='args')
|
||||||
|
parser.add_argument('code', type=str, required=True, location='args')
|
||||||
|
parser.add_argument('oidc-state', type=str, required=True, location='cookies')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = EnterpriseSSOService.get_sso_oidc_callback(args)
|
||||||
|
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||||
|
except Exception as e:
|
||||||
|
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
|
||||||
|
api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
|
||||||
|
api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
|
||||||
|
api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')
|
|
@ -1,6 +1,7 @@
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource
|
from flask_restful import Resource
|
||||||
|
|
||||||
|
from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api
|
from . import api
|
||||||
|
@ -14,4 +15,10 @@ class FeatureApi(Resource):
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).dict()
|
return FeatureService.get_features(current_user.current_tenant_id).dict()
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseFeatureApi(Resource):
|
||||||
|
def get(self):
|
||||||
|
return EnterpriseFeatureService.get_enterprise_features().dict()
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FeatureApi, '/features')
|
api.add_resource(FeatureApi, '/features')
|
||||||
|
api.add_resource(EnterpriseFeatureApi, '/enterprise-features')
|
||||||
|
|
|
@ -58,6 +58,8 @@ class SetupApi(Resource):
|
||||||
password=args['password']
|
password=args['password']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TenantService.create_owner_tenant_if_not_exist(account)
|
||||||
|
|
||||||
setup()
|
setup()
|
||||||
AccountService.update_last_login(account, request)
|
AccountService.update_last_login(account, request)
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
|
@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import Tenant
|
from models.account import Tenant, TenantStatus
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
from services.workspace_service import WorkspaceService
|
from services.workspace_service import WorkspaceService
|
||||||
|
@ -116,6 +117,16 @@ class TenantApi(Resource):
|
||||||
|
|
||||||
tenant = current_user.current_tenant
|
tenant = current_user.current_tenant
|
||||||
|
|
||||||
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
|
tenants = TenantService.get_join_tenants(current_user)
|
||||||
|
# if there is any tenant, switch to the first one
|
||||||
|
if len(tenants) > 0:
|
||||||
|
TenantService.switch_tenant(current_user, tenants[0].id)
|
||||||
|
tenant = tenants[0]
|
||||||
|
# else, raise Unauthorized
|
||||||
|
else:
|
||||||
|
raise Unauthorized('workspace is archived')
|
||||||
|
|
||||||
return WorkspaceService.get_tenant_info(tenant), 200
|
return WorkspaceService.get_tenant_info(tenant), 200
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
from flask import Blueprint
|
||||||
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
|
bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
|
||||||
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
from .workspace import workspace
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.inner_api import api
|
||||||
|
from controllers.inner_api.wraps import inner_api_only
|
||||||
|
from events.tenant_event import tenant_was_created
|
||||||
|
from models.account import Account
|
||||||
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseWorkspace(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@inner_api_only
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('name', type=str, required=True, location='json')
|
||||||
|
parser.add_argument('owner_email', type=str, required=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=args['owner_email']).first()
|
||||||
|
if account is None:
|
||||||
|
return {
|
||||||
|
'message': 'owner account not found.'
|
||||||
|
}, 404
|
||||||
|
|
||||||
|
tenant = TenantService.create_tenant(args['name'])
|
||||||
|
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||||
|
|
||||||
|
tenant_was_created.send(tenant)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'message': 'enterprise workspace created.'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
|
|
@ -0,0 +1,61 @@
|
||||||
|
from base64 import b64encode
|
||||||
|
from functools import wraps
|
||||||
|
from hashlib import sha1
|
||||||
|
from hmac import new as hmac_new
|
||||||
|
|
||||||
|
from flask import abort, current_app, request
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
||||||
|
def inner_api_only(view):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
if not current_app.config['INNER_API']:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
# get header 'X-Inner-Api-Key'
|
||||||
|
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||||
|
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
|
||||||
|
abort(404)
|
||||||
|
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def inner_api_user_auth(view):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
if not current_app.config['INNER_API']:
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
# get header 'X-Inner-Api-Key'
|
||||||
|
authorization = request.headers.get('Authorization')
|
||||||
|
if not authorization:
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
parts = authorization.split(':')
|
||||||
|
if len(parts) != 2:
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
user_id, token = parts
|
||||||
|
if ' ' in user_id:
|
||||||
|
user_id = user_id.split(' ')[1]
|
||||||
|
|
||||||
|
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||||
|
|
||||||
|
data_to_sign = f'DIFY {user_id}'
|
||||||
|
|
||||||
|
signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
|
||||||
|
signature = b64encode(signature.digest()).decode('utf-8')
|
||||||
|
|
||||||
|
if signature != token:
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||||
|
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
|
@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import _get_user
|
from libs.login import _get_user
|
||||||
from models.account import Account, Tenant, TenantAccountJoin
|
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||||
from models.model import ApiToken, App, EndUser
|
from models.model import ApiToken, App, EndUser
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||||
if not app_model.enable_api:
|
if not app_model.enable_api:
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
|
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
|
||||||
|
if tenant.status == TenantStatus.ARCHIVE:
|
||||||
|
raise NotFound()
|
||||||
|
|
||||||
kwargs['app_model'] = app_model
|
kwargs['app_model'] = app_model
|
||||||
|
|
||||||
if fetch_user_arg:
|
if fetch_user_arg:
|
||||||
|
@ -137,6 +141,7 @@ def validate_dataset_token(view=None):
|
||||||
.filter(Tenant.id == api_token.tenant_id) \
|
.filter(Tenant.id == api_token.tenant_id) \
|
||||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||||
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
||||||
|
.filter(Tenant.status == TenantStatus.NORMAL) \
|
||||||
.one_or_none() # TODO: only owner information is required, so only one is returned.
|
.one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||||
if tenant_account_join:
|
if tenant_account_join:
|
||||||
tenant, ta = tenant_account_join
|
tenant, ta = tenant_account_join
|
||||||
|
|
|
@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.account import TenantStatus
|
||||||
from models.model import Site
|
from models.model import Site
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource):
|
||||||
if not site:
|
if not site:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
|
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||||
|
|
||||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||||
|
|
|
@ -105,6 +105,12 @@ class Account(UserMixin, db.Model):
|
||||||
def is_admin_or_owner(self):
|
def is_admin_or_owner(self):
|
||||||
return self._current_tenant.current_role in ['admin', 'owner']
|
return self._current_tenant.current_role in ['admin', 'owner']
|
||||||
|
|
||||||
|
|
||||||
|
class TenantStatus(str, enum.Enum):
|
||||||
|
NORMAL = 'normal'
|
||||||
|
ARCHIVE = 'archive'
|
||||||
|
|
||||||
|
|
||||||
class Tenant(db.Model):
|
class Tenant(db.Model):
|
||||||
__tablename__ = 'tenants'
|
__tablename__ = 'tenants'
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
from constants.languages import language_timezone_mapping, languages
|
from constants.languages import language_timezone_mapping, languages
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
|
@ -44,7 +44,7 @@ class AccountService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
|
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
|
||||||
raise Forbidden('Account is banned or closed.')
|
raise Unauthorized("Account is banned or closed.")
|
||||||
|
|
||||||
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||||
if current_tenant:
|
if current_tenant:
|
||||||
|
@ -255,7 +255,7 @@ class TenantService:
|
||||||
"""Get account join tenants"""
|
"""Get account join tenants"""
|
||||||
return db.session.query(Tenant).join(
|
return db.session.query(Tenant).join(
|
||||||
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
|
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
|
||||||
).filter(TenantAccountJoin.account_id == account.id).all()
|
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_current_tenant_by_account(account: Account):
|
def get_current_tenant_by_account(account: Account):
|
||||||
|
@ -279,7 +279,12 @@ class TenantService:
|
||||||
if tenant_id is None:
|
if tenant_id is None:
|
||||||
raise ValueError("Tenant ID must be provided.")
|
raise ValueError("Tenant ID must be provided.")
|
||||||
|
|
||||||
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
|
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
|
||||||
|
TenantAccountJoin.account_id == account.id,
|
||||||
|
TenantAccountJoin.tenant_id == tenant_id,
|
||||||
|
Tenant.status == TenantStatus.NORMAL,
|
||||||
|
).first()
|
||||||
|
|
||||||
if not tenant_account_join:
|
if not tenant_account_join:
|
||||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseRequest:
|
||||||
|
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
|
||||||
|
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_request(cls, method, endpoint, json=None, params=None):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Enterprise-Api-Secret-Key": cls.secret_key
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{cls.base_url}{endpoint}"
|
||||||
|
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||||
|
|
||||||
|
return response.json()
|
|
@ -0,0 +1,28 @@
|
||||||
|
from flask import current_app
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseFeatureModel(BaseModel):
|
||||||
|
sso_enforced_for_signin: bool = False
|
||||||
|
sso_enforced_for_signin_protocol: str = ''
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseFeatureService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_enterprise_features(cls) -> EnterpriseFeatureModel:
|
||||||
|
features = EnterpriseFeatureModel()
|
||||||
|
|
||||||
|
if current_app.config['ENTERPRISE_ENABLED']:
|
||||||
|
cls._fulfill_params_from_enterprise(features)
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fulfill_params_from_enterprise(cls, features):
|
||||||
|
enterprise_info = EnterpriseService.get_info()
|
||||||
|
|
||||||
|
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
|
||||||
|
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
|
|
@ -0,0 +1,8 @@
|
||||||
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_info(cls):
|
||||||
|
return EnterpriseRequest.send_request('GET', '/info')
|
|
@ -0,0 +1,60 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from models.account import Account, AccountStatus
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseSSOService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_sso_saml_login(cls) -> str:
|
||||||
|
return EnterpriseRequest.send_request('GET', '/sso/saml/login')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_sso_saml_acs(cls, saml_response: str) -> str:
|
||||||
|
response = EnterpriseRequest.send_request('POST', '/sso/saml/acs', json={'SAMLResponse': saml_response})
|
||||||
|
if 'email' not in response or response['email'] is None:
|
||||||
|
logger.exception(response)
|
||||||
|
raise Exception('Saml response is invalid')
|
||||||
|
|
||||||
|
return cls.login_with_email(response.get('email'))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_sso_oidc_login(cls):
|
||||||
|
return EnterpriseRequest.send_request('GET', '/sso/oidc/login')
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_sso_oidc_callback(cls, args: dict):
|
||||||
|
state_from_query = args['state']
|
||||||
|
code_from_query = args['code']
|
||||||
|
state_from_cookies = args['oidc-state']
|
||||||
|
|
||||||
|
if state_from_cookies != state_from_query:
|
||||||
|
raise Exception('invalid state or code')
|
||||||
|
|
||||||
|
response = EnterpriseRequest.send_request('GET', '/sso/oidc/callback', params={'code': code_from_query})
|
||||||
|
if 'email' not in response or response['email'] is None:
|
||||||
|
logger.exception(response)
|
||||||
|
raise Exception('OIDC response is invalid')
|
||||||
|
|
||||||
|
return cls.login_with_email(response.get('email'))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def login_with_email(cls, email: str) -> str:
|
||||||
|
account = Account.query.filter_by(email=email).first()
|
||||||
|
if account is None:
|
||||||
|
raise Exception('account not found, please contact system admin to invite you to join in a workspace')
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED:
|
||||||
|
raise Exception('account is banned, please contact system admin')
|
||||||
|
|
||||||
|
tenants = TenantService.get_join_tenants(account)
|
||||||
|
if len(tenants) == 0:
|
||||||
|
raise Exception("workspace not found, please contact system admin to invite you to join in a workspace")
|
||||||
|
|
||||||
|
token = AccountService.get_account_jwt_token(account)
|
||||||
|
|
||||||
|
return token
|
|
@ -39,6 +39,10 @@ export default function AppSelector({ isMobile }: IAppSelecotr) {
|
||||||
url: '/logout',
|
url: '/logout',
|
||||||
params: {},
|
params: {},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if (localStorage?.getItem('console_token'))
|
||||||
|
localStorage.removeItem('console_token')
|
||||||
|
|
||||||
router.push('/signin')
|
router.push('/signin')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,9 +10,6 @@ import LogoSite from '@/app/components/base/logo/logo-site'
|
||||||
const Header = () => {
|
const Header = () => {
|
||||||
const { locale, setLocaleOnClient } = useContext(I18n)
|
const { locale, setLocaleOnClient } = useContext(I18n)
|
||||||
|
|
||||||
if (localStorage?.getItem('console_token'))
|
|
||||||
localStorage.removeItem('console_token')
|
|
||||||
|
|
||||||
return <div className='flex items-center justify-between p-6 w-full'>
|
return <div className='flex items-center justify-between p-6 w-full'>
|
||||||
<LogoSite />
|
<LogoSite />
|
||||||
<Select
|
<Select
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
'use client'
|
||||||
|
import cn from 'classnames'
|
||||||
|
import { useRouter, useSearchParams } from 'next/navigation'
|
||||||
|
import type { FC } from 'react'
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import Toast from '@/app/components/base/toast'
|
||||||
|
import { getOIDCSSOUrl, getSAMLSSOUrl } from '@/service/enterprise'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
|
||||||
|
type EnterpriseSSOFormProps = {
|
||||||
|
protocol: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const EnterpriseSSOForm: FC<EnterpriseSSOFormProps> = ({
|
||||||
|
protocol,
|
||||||
|
}) => {
|
||||||
|
const searchParams = useSearchParams()
|
||||||
|
const consoleToken = searchParams.get('console_token')
|
||||||
|
const message = searchParams.get('message')
|
||||||
|
|
||||||
|
const router = useRouter()
|
||||||
|
const { t } = useTranslation()
|
||||||
|
|
||||||
|
const [isLoading, setIsLoading] = useState(false)
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (consoleToken) {
|
||||||
|
localStorage.setItem('console_token', consoleToken)
|
||||||
|
router.replace('/apps')
|
||||||
|
}
|
||||||
|
|
||||||
|
if (message) {
|
||||||
|
Toast.notify({
|
||||||
|
type: 'error',
|
||||||
|
message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const handleSSOLogin = () => {
|
||||||
|
setIsLoading(true)
|
||||||
|
if (protocol === 'saml') {
|
||||||
|
getSAMLSSOUrl().then((res) => {
|
||||||
|
router.push(res.url)
|
||||||
|
}).finally(() => {
|
||||||
|
setIsLoading(false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
getOIDCSSOUrl().then((res) => {
|
||||||
|
document.cookie = `oidc-state=${res.state}`
|
||||||
|
router.push(res.url)
|
||||||
|
}).finally(() => {
|
||||||
|
setIsLoading(false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={
|
||||||
|
cn(
|
||||||
|
'flex flex-col items-center w-full grow items-center justify-center',
|
||||||
|
'px-6',
|
||||||
|
'md:px-[108px]',
|
||||||
|
)
|
||||||
|
}>
|
||||||
|
<div className='flex flex-col md:w-[400px]'>
|
||||||
|
<div className="w-full mx-auto">
|
||||||
|
<h2 className="text-[32px] font-bold text-gray-900">{t('login.pageTitle')}</h2>
|
||||||
|
</div>
|
||||||
|
<div className="w-full mx-auto mt-10">
|
||||||
|
<Button
|
||||||
|
tabIndex={0}
|
||||||
|
type='primary'
|
||||||
|
onClick={() => { handleSSOLogin() }}
|
||||||
|
disabled={isLoading}
|
||||||
|
className="w-full !fone-medium !text-sm"
|
||||||
|
>{t('login.sso')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default EnterpriseSSOForm
|
|
@ -96,8 +96,17 @@ const NormalForm = () => {
|
||||||
remember_me: true,
|
remember_me: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
localStorage.setItem('console_token', res.data)
|
|
||||||
router.replace('/apps')
|
if (res.result === 'success') {
|
||||||
|
localStorage.setItem('console_token', res.data)
|
||||||
|
router.replace('/apps')
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Toast.notify({
|
||||||
|
type: 'error',
|
||||||
|
message: res.data,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
finally {
|
finally {
|
||||||
setIsLoading(false)
|
setIsLoading(false)
|
||||||
|
|
|
@ -1,12 +1,29 @@
|
||||||
import React from 'react'
|
'use client'
|
||||||
|
import React, { useEffect, useState } from 'react'
|
||||||
import cn from 'classnames'
|
import cn from 'classnames'
|
||||||
import Script from 'next/script'
|
import Script from 'next/script'
|
||||||
|
import Loading from '../components/base/loading'
|
||||||
import Forms from './forms'
|
import Forms from './forms'
|
||||||
import Header from './_header'
|
import Header from './_header'
|
||||||
import style from './page.module.css'
|
import style from './page.module.css'
|
||||||
|
import EnterpriseSSOForm from './enterpriseSSOForm'
|
||||||
import { IS_CE_EDITION } from '@/config'
|
import { IS_CE_EDITION } from '@/config'
|
||||||
|
import { getEnterpriseFeatures } from '@/service/enterprise'
|
||||||
|
import type { EnterpriseFeatures } from '@/types/enterprise'
|
||||||
|
import { defaultEnterpriseFeatures } from '@/types/enterprise'
|
||||||
|
|
||||||
const SignIn = () => {
|
const SignIn = () => {
|
||||||
|
const [loading, setLoading] = useState<boolean>(true)
|
||||||
|
const [enterpriseFeatures, setEnterpriseFeatures] = useState<EnterpriseFeatures>(defaultEnterpriseFeatures)
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
getEnterpriseFeatures().then((res) => {
|
||||||
|
setEnterpriseFeatures(res)
|
||||||
|
}).finally(() => {
|
||||||
|
setLoading(false)
|
||||||
|
})
|
||||||
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{!IS_CE_EDITION && (
|
{!IS_CE_EDITION && (
|
||||||
|
@ -40,10 +57,31 @@ gtag('config', 'AW-11217955271"');
|
||||||
)
|
)
|
||||||
}>
|
}>
|
||||||
<Header />
|
<Header />
|
||||||
<Forms />
|
|
||||||
<div className='px-8 py-6 text-sm font-normal text-gray-500'>
|
{loading && (
|
||||||
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
|
<div className={
|
||||||
</div>
|
cn(
|
||||||
|
'flex flex-col items-center w-full grow items-center justify-center',
|
||||||
|
'px-6',
|
||||||
|
'md:px-[108px]',
|
||||||
|
)
|
||||||
|
}>
|
||||||
|
<Loading type='area' />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!loading && !enterpriseFeatures.sso_enforced_for_signin && (
|
||||||
|
<>
|
||||||
|
<Forms />
|
||||||
|
<div className='px-8 py-6 text-sm font-normal text-gray-500'>
|
||||||
|
© {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!loading && enterpriseFeatures.sso_enforced_for_signin && (
|
||||||
|
<EnterpriseSSOForm protocol={enterpriseFeatures.sso_enforced_for_signin_protocol} />
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -9,6 +9,7 @@ const translation = {
|
||||||
namePlaceholder: 'Your username',
|
namePlaceholder: 'Your username',
|
||||||
forget: 'Forgot your password?',
|
forget: 'Forgot your password?',
|
||||||
signBtn: 'Sign in',
|
signBtn: 'Sign in',
|
||||||
|
sso: 'Continue with SSO',
|
||||||
installBtn: 'Set up',
|
installBtn: 'Set up',
|
||||||
setAdminAccount: 'Setting up an admin account',
|
setAdminAccount: 'Setting up an admin account',
|
||||||
setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.',
|
setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.',
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
import { get } from './base'
|
||||||
|
import type { EnterpriseFeatures } from '@/types/enterprise'
|
||||||
|
|
||||||
|
export const getEnterpriseFeatures = () => {
|
||||||
|
return get<EnterpriseFeatures>('/enterprise-features')
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getSAMLSSOUrl = () => {
|
||||||
|
return get<{ url: string }>('/enterprise/sso/saml/login')
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getOIDCSSOUrl = () => {
|
||||||
|
return get<{ url: string; state: string }>('/enterprise/sso/oidc/login')
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
export type EnterpriseFeatures = {
|
||||||
|
sso_enforced_for_signin: boolean
|
||||||
|
sso_enforced_for_signin_protocol: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export const defaultEnterpriseFeatures: EnterpriseFeatures = {
|
||||||
|
sso_enforced_for_signin: false,
|
||||||
|
sso_enforced_for_signin_protocol: '',
|
||||||
|
}
|
Loading…
Reference in New Issue