添加OAuth2提供者的支持
This commit is contained in:
parent
2f57b5557d
commit
12bb8524d6
|
@ -24,6 +24,11 @@ class AuthException(Exception):
|
|||
self.message = message
|
||||
self.data = data
|
||||
|
||||
class OAuth2Error(Exception):
|
||||
def __init__(self, description, status_code=400):
|
||||
self.description = description
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
def global_exception_handler(error):
|
||||
from dash import set_props
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
from common.exception import OAuth2Error
|
||||
from flask import request
|
||||
|
||||
authorize_html = """
|
||||
<p>The application <strong>{{client_id}}</strong> is requesting:
|
||||
<strong>{{ scope }}</strong>
|
||||
</p>
|
||||
|
||||
<p>
|
||||
from You - a.k.a. <strong>{{ user_name }}</strong>
|
||||
</p>
|
||||
|
||||
<form action="" method="post">
|
||||
<label>
|
||||
<input type="checkbox" name="confirm">
|
||||
<span>Consent?</span>
|
||||
</label>
|
||||
<br>
|
||||
<button>Submit</button>
|
||||
</form>
|
||||
"""
|
||||
|
||||
def require_oauth(required_scope):
|
||||
from functools import wraps
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
def decorator(*args, **kwargs):
|
||||
token = current_token()
|
||||
if not token:
|
||||
raise OAuth2Error('Invalid access token', 401)
|
||||
if not token.is_valid():
|
||||
raise OAuth2Error('Token expired', 401)
|
||||
if required_scope not in token.scope.split():
|
||||
raise OAuth2Error('Invalid scope', 401)
|
||||
func(*args, **kwargs)
|
||||
|
||||
return decorator
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def current_token():
|
||||
from database.sql_db.dao.dao_oauth2 import exist_token
|
||||
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
raise OAuth2Error('Invalid token', 401)
|
||||
return exist_token(auth_header.split(' ')[-1])
|
|
@ -57,4 +57,10 @@ POOL_SIZE = 5
|
|||
[ProxyConf]
|
||||
NGINX_PROXY = False
|
||||
|
||||
[OAuth2Conf]
|
||||
OAuth2AuthorizationCodeExpiresInMinutes = 60
|
||||
OAuth2AuthorizationCodeLength = 32
|
||||
OAuth2TokenExpiresInMinutes = 120
|
||||
OAuth2TokenLength = 32
|
||||
|
||||
[SqlCacheConf]
|
||||
|
|
|
@ -7,7 +7,7 @@ class PathProj:
|
|||
ROOT_PATH = Path(__file__).parent.parent
|
||||
CONF_FILE_PATH = ROOT_PATH / 'config' / 'dashgo.ini'
|
||||
AVATAR_DIR_PATH = (ROOT_PATH / '..' / 'user_data' / 'avatars').resolve()
|
||||
AVATAR_DIR_PATH.mkdir(parents=True,exist_ok=True)
|
||||
AVATAR_DIR_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
conf = ConfigParser()
|
||||
|
@ -64,9 +64,18 @@ class JwtConf(metaclass=BaseMetaConf):
|
|||
JWT_ALGORITHM: str = 'HS256'
|
||||
JWT_EXPIRE_MINUTES: int = 1440
|
||||
|
||||
|
||||
class ProxyConf(metaclass=BaseMetaConf):
|
||||
NGINX_PROXY: bool = False
|
||||
|
||||
|
||||
class OAuth2Conf(metaclass=BaseMetaConf):
|
||||
OAuth2AuthorizationCodeExpiresInMinutes: int
|
||||
OAuth2AuthorizationCodeLength: int
|
||||
OAuth2TokenExpiresInMinutes: int
|
||||
OAuth2TokenLength: int
|
||||
|
||||
|
||||
class SqlDbConf(metaclass=BaseMetaConf):
|
||||
RDB_TYPE: str
|
||||
SQLITE_DB_PATH: str
|
||||
|
|
|
@ -8,17 +8,24 @@ import dash
|
|||
from i18n import t__other
|
||||
|
||||
|
||||
# 定义一个客户端回调函数,用于处理登录验证代码的显示逻辑,总是显示login的路径
|
||||
# 定义一个客户端回调函数,用于处理登录验证代码的显示逻辑,总是显示login的路径,如果有next的query参数,则代表是OAuth2的请求
|
||||
app.clientside_callback(
|
||||
"""
|
||||
(fc_count,timeoutCount) => {
|
||||
(fc_count,timeoutCount,url) => {
|
||||
fc_count=fc_count || 0;
|
||||
const urlObj = new URL(url);
|
||||
const searchParams = new URLSearchParams(urlObj.search);
|
||||
if (searchParams.has('next')) {
|
||||
title = 'OAuth2 Login';
|
||||
} else {
|
||||
title = window.dash_clientside.no_update;
|
||||
}
|
||||
if (fc_count>="""
|
||||
+ str(LoginConf.VERIFY_CODE_SHOW_LOGIN_FAIL_COUNT)
|
||||
+ """) {
|
||||
return [{'display':'flex'}, {'height': 'max(40%,600px)'}, 1, '/login'];
|
||||
return [{'display':'flex'}, {'height': 'max(40%,600px)'}, 1, '/login', title];
|
||||
}
|
||||
return [{'display':'None'}, {'height': 'max(35%,500px)'}, 0, '/login'];
|
||||
return [{'display':'None'}, {'height': 'max(35%,500px)'}, 0, '/login', title];
|
||||
}
|
||||
""",
|
||||
[
|
||||
|
@ -26,11 +33,13 @@ app.clientside_callback(
|
|||
Output('login-container', 'style'),
|
||||
Output('login-store-need-vc', 'data'),
|
||||
Output('login-location-no-refresh', 'pathname'),
|
||||
Output('login-title', 'children'),
|
||||
],
|
||||
[
|
||||
Input('login-store-fc', 'data'),
|
||||
Input('timeout-trigger-verify-code', 'timeoutCount'),
|
||||
],
|
||||
State('login-location-no-refresh', 'href'),
|
||||
prevent_initial_call=True,
|
||||
)
|
||||
|
||||
|
@ -227,7 +236,7 @@ def login(
|
|||
|
||||
if user_login(user_name, password_sha256, is_keep_login_status):
|
||||
return (
|
||||
dcc.Location(pathname='/dashboard_/workbench', refresh=True, id='index-redirect'),
|
||||
dcc.Location(pathname='/', refresh=True, id='index-redirect'),
|
||||
0, # 重置登录失败次数
|
||||
dash.no_update,
|
||||
dash.no_update,
|
||||
|
@ -268,7 +277,7 @@ def otp_login(otp_value, user_name):
|
|||
if totp.verify(int(otp_value)):
|
||||
jwt_encode_save_access_to_session({'user_name': user_name}, session_permanent=False)
|
||||
return (
|
||||
dcc.Location(pathname='/dashboard_/workbench', refresh=True, id='index-redirect'),
|
||||
dcc.Location(pathname='/', refresh=True, id='index-redirect'),
|
||||
dash.no_update,
|
||||
dash.no_update,
|
||||
)
|
||||
|
|
|
@ -14,6 +14,7 @@ def render_content():
|
|||
id='login-container',
|
||||
children=[
|
||||
fuc.FefferyDiv(
|
||||
id='login-title',
|
||||
children=ShowConf.APP_NAME,
|
||||
className={
|
||||
'fontWeight': 'bold',
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from config.dashgo_conf import SqlDbConf
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
from peewee import SqliteDatabase
|
||||
from server import server
|
||||
from playhouse.shortcuts import ReconnectMixin
|
||||
|
||||
|
||||
|
@ -44,6 +43,7 @@ def initialize_database():
|
|||
if not db_instance.table_exists('sys_user'):
|
||||
from .entity.table_user import SysUser, SysRoleAccessMeta, SysUserRole, SysGroupUser, SysRole, SysGroupRole, SysGroup
|
||||
from .entity.table_announcement import SysAnnouncement
|
||||
from .entity.table_oauth2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
|
||||
|
@ -57,6 +57,9 @@ def initialize_database():
|
|||
SysGroupRole,
|
||||
SysGroup,
|
||||
SysAnnouncement,
|
||||
OAuth2Client,
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2Token
|
||||
]
|
||||
)
|
||||
SysRole.create(
|
||||
|
@ -86,14 +89,4 @@ def initialize_database():
|
|||
SysUserRole.create(user_name='admin', role_name='admin')
|
||||
|
||||
|
||||
# 自动管理数据库上下文
|
||||
@server.before_request
|
||||
def _db_connect():
|
||||
db().connect(reuse_if_open=True)
|
||||
|
||||
|
||||
@server.teardown_request
|
||||
def _db_close(exc):
|
||||
_db = db()
|
||||
if not _db.is_closed():
|
||||
_db.close()
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
from peewee import DoesNotExist
|
||||
from database.sql_db.entity.table_oauth2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def exist_client(client_id) -> Optional[OAuth2Client]:
|
||||
try:
|
||||
client: OAuth2Client = OAuth2Client.get(OAuth2Client.client_id == client_id)
|
||||
return client
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def insert_authorization_code(code, client_id, user_name, redirect_uri, expires_at, scope) -> bool:
|
||||
try:
|
||||
OAuth2AuthorizationCode.create(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
user_name=user_name,
|
||||
redirect_uri=redirect_uri,
|
||||
expires_at=expires_at,
|
||||
scope=scope,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
|
||||
def exist_code(code, client_id) -> Optional[OAuth2AuthorizationCode]:
|
||||
try:
|
||||
code: OAuth2AuthorizationCode = OAuth2AuthorizationCode.get((OAuth2AuthorizationCode.code == code) & (OAuth2AuthorizationCode.client_id == client_id))
|
||||
return code
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def validate_client(client_id, client_secret) -> Optional[OAuth2Client]:
|
||||
"""验证客户端凭证"""
|
||||
try:
|
||||
client: OAuth2Client = OAuth2Client.get(OAuth2Client.client_id == client_id)
|
||||
if client.client_secret == client_secret:
|
||||
return client
|
||||
return None
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def insert_token(token, client_id, user_name, expires_at, scope) -> bool:
|
||||
try:
|
||||
OAuth2Token.create(
|
||||
token=token,
|
||||
client_id=client_id,
|
||||
user_name=user_name,
|
||||
expires_at=expires_at,
|
||||
scope=scope,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
|
||||
def exist_token(token) -> Optional[OAuth2Token]:
|
||||
try:
|
||||
token: OAuth2Token = OAuth2Token.get(OAuth2Token.token == token)
|
||||
return token
|
||||
except DoesNotExist:
|
||||
return None
|
|
@ -770,7 +770,6 @@ def get_dict_group_name_users_roles(user_name) -> Dict[str, Union[str, Set]]:
|
|||
|
||||
for group_name in group_names:
|
||||
group_remark, user_names, group_roles = get_user_and_role_for_group_name(group_name=group_name)
|
||||
print(group_roles)
|
||||
user_infos = get_user_info(user_names=user_names)
|
||||
dict_user_info = {i.user_name: i for i in user_infos}
|
||||
for user_name_per, user_info in dict_user_info.items():
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
from peewee import Model, CharField, TextField, DateTimeField
|
||||
from ..conn import db
|
||||
from datetime import datetime
|
||||
import secrets
|
||||
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = db()
|
||||
|
||||
|
||||
class OAuth2Client(BaseModel):
|
||||
"""注册的三方客户端信息"""
|
||||
|
||||
client_id = CharField(max_length=48, help_text='客户端ID')
|
||||
client_secret = CharField(max_length=120, help_text='客户端密钥')
|
||||
redirect_uris = TextField(help_text='允许的回调地址')
|
||||
scope = TextField(help_text='权限范围')
|
||||
|
||||
class Meta:
|
||||
indexes = ((('client_id',), True),)
|
||||
|
||||
# grant阶段验证
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri in self.redirect_uris.split()
|
||||
|
||||
def check_scope(self, redirect_uris):
|
||||
return set(redirect_uris).issubset(set(self.redirect_uris.split()))
|
||||
|
||||
# token阶段验证
|
||||
def check_client_secret(self, client_secret):
|
||||
return secrets.compare_digest(self.client_secret, client_secret)
|
||||
|
||||
def check_grant_type(self, grant_type):
|
||||
return grant_type == 'authorization_code'
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(BaseModel):
|
||||
"""生成的随机授权码"""
|
||||
|
||||
code = CharField(max_length=120, help_text='授权码')
|
||||
client_id = CharField(max_length=48, help_text='客户端ID')
|
||||
user_name = CharField(max_length=32, help_text='用户名')
|
||||
redirect_uri = CharField(max_length=120, help_text='回调地址')
|
||||
expires_at = DateTimeField(help_text='过期时间')
|
||||
scope = TextField(help_text='权限范围')
|
||||
|
||||
class Meta:
|
||||
indexes = ((('code',), True),)
|
||||
|
||||
def is_valid(self):
|
||||
return self.expires_at > datetime.now()
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri == self.redirect_uri
|
||||
|
||||
def check_client_id(self, client_id):
|
||||
return client_id == self.client_id
|
||||
|
||||
|
||||
class OAuth2Token(BaseModel):
|
||||
"""颁发的访问令牌"""
|
||||
|
||||
token = CharField(max_length=48, help_text='访问令牌')
|
||||
client_id = CharField(max_length=48, help_text='客户端ID')
|
||||
user_name = CharField(max_length=32, help_text='用户名')
|
||||
expires_at = DateTimeField(help_text='过期时间')
|
||||
scope = TextField(help_text='权限范围')
|
||||
|
||||
class Meta:
|
||||
indexes = ((('token',), True),)
|
||||
|
||||
def is_valid(self):
|
||||
from datetime import datetime
|
||||
|
||||
"""检查令牌是否有效"""
|
||||
return self.expires_at > datetime.now()
|
198
src/server.py
198
src/server.py
|
@ -1,9 +1,12 @@
|
|||
from flask import request, redirect, send_from_directory, abort
|
||||
from flask import request, redirect, send_from_directory, abort, jsonify
|
||||
from common.exception import OAuth2Error
|
||||
from common.utilities.util_oauth2 import current_token, require_oauth
|
||||
from config.dashgo_conf import ShowConf, FlaskConf, CommonConf, PathProj
|
||||
from common.utilities.util_logger import Log
|
||||
from common.exception import global_exception_handler
|
||||
from common.utilities.util_dash import CustomDash
|
||||
from common.constant import HttpStatusConstant
|
||||
from datetime import datetime, timedelta
|
||||
from i18n import t__other
|
||||
|
||||
|
||||
|
@ -33,7 +36,6 @@ server = app.server
|
|||
def download_file(user_name):
|
||||
file_name = f'{user_name}.jpg'
|
||||
if '..' in user_name:
|
||||
|
||||
logger.warning(f'有人尝试通过头像文件接口攻击,URL:{request.url},IP:{request.remote_addr}')
|
||||
abort(HttpStatusConstant.FORBIDDEN)
|
||||
else:
|
||||
|
@ -49,18 +51,12 @@ def ban_bypass_proxy():
|
|||
abort(HttpStatusConstant.FORBIDDEN)
|
||||
|
||||
|
||||
# 首页重定向
|
||||
@server.before_request
|
||||
def main_page_redirct():
|
||||
if request.path == '/':
|
||||
return redirect('/dashboard_/workbench')
|
||||
|
||||
|
||||
# 恶意访问管理页面拦截器
|
||||
@server.before_request
|
||||
def ban_admin():
|
||||
if request.path.startswith('/admin'):
|
||||
from common.utilities.util_browser import get_browser_info
|
||||
|
||||
browser_info = get_browser_info()
|
||||
logger.warning(f'有人尝试访问不存在的管理页面,URL:{browser_info.url},IP:{browser_info.request_addr}')
|
||||
abort(HttpStatusConstant.NOT_FOUND)
|
||||
|
@ -79,3 +75,187 @@ def get_user_agent_info():
|
|||
browser_info.request_addr,
|
||||
t__other('Chrome内核版本号太低,请升级浏览器'),
|
||||
)
|
||||
|
||||
|
||||
# 自动管理数据库上下文
|
||||
@server.before_request
|
||||
def _db_connect():
|
||||
from database.sql_db.conn import db
|
||||
|
||||
db().connect(reuse_if_open=True)
|
||||
|
||||
|
||||
@server.teardown_request
|
||||
def _db_close(exc):
|
||||
from database.sql_db.conn import db
|
||||
|
||||
_db = db()
|
||||
if not _db.is_closed():
|
||||
_db.close()
|
||||
|
||||
|
||||
@server.route('/oauth/authorize', methods=['GET', 'POST'])
|
||||
def authorize():
|
||||
"""第一步grant用户确认阶段
|
||||
OAuth2授权端点
|
||||
参数:
|
||||
- client_id: 客户端ID(必须)
|
||||
- redirect_uri: 重定向URI(必须)
|
||||
- response_type: 必须为'code'
|
||||
- scope: 请求的权限范围(可选,本项目为必选)
|
||||
- state: 客户端状态值(可选,本项目为必选)
|
||||
"""
|
||||
from flask import request, render_template_string
|
||||
from common.utilities.util_jwt import jwt_decode_from_session, AccessFailType
|
||||
from common.utilities.util_oauth2 import authorize_html
|
||||
from config.dashgo_conf import OAuth2Conf
|
||||
from database.sql_db.dao.dao_oauth2 import exist_client, insert_authorization_code
|
||||
from yarl import URL
|
||||
import secrets
|
||||
|
||||
# 1. 如果没登陆,登录了再来认证
|
||||
if isinstance((rt_access := jwt_decode_from_session()), AccessFailType):
|
||||
return redirect(URL.build(path='/login').with_query({'next': request.url}).__str__())
|
||||
### 参数检查
|
||||
user_name = rt_access['user_name']
|
||||
# 检查client_id
|
||||
if request.args.get('client_id') is None:
|
||||
raise OAuth2Error('Invalid_client_id')
|
||||
client = exist_client(client_id=request.args.get('client_id'))
|
||||
# 检查scope
|
||||
if request.args.get('scope') is None or client.check_scope(request.args.get('scope').split()):
|
||||
raise OAuth2Error('Invalid_scope')
|
||||
# 检查redirect_uri
|
||||
if request.args.get('redirect_uri') is None or client.check_redirect_uri(request.args.get('redirect_uri')):
|
||||
raise OAuth2Error('Invalid_redirect_uri')
|
||||
# 检查response_type
|
||||
if request.args.get('response_type') != 'code':
|
||||
raise OAuth2Error('Invalid_response_type')
|
||||
# 检查state
|
||||
if request.args.get('state') is None:
|
||||
raise OAuth2Error('Invalid_state')
|
||||
# 2. 登录啦,是否同意授权?
|
||||
if request.method == 'GET' and request.args.get('confirm') != 'yes':
|
||||
return render_template_string(authorize_html, scope=request.args.get('scope'), client_id=request.args.get('client_id'), user_name=user_name)
|
||||
|
||||
# 3. 同意授权
|
||||
grant_user = rt_access['user_name'] if request.form['confirm'] or request.args.get('confirm') == 'yes' else None
|
||||
if grant_user is None:
|
||||
raise OAuth2Error('NOT_IMPLEMENTED', HttpStatusConstant.NOT_IMPLEMENTED)
|
||||
# 生成code授权码
|
||||
if insert_authorization_code(
|
||||
code=(auth_code := secrets.token_urlsafe(OAuth2Conf.OAuth2AuthorizationCodeLength)),
|
||||
client_id=request.args.get('client_id'),
|
||||
user_name=user_name,
|
||||
expires_at=datetime.now() + timedelta(minutes=OAuth2Conf.OAuth2AuthorizationCodeExpiresInMinutes),
|
||||
redirect_uri=request.args.get('redirect_uri'),
|
||||
scope=request.args.get('scope'),
|
||||
):
|
||||
return redirect(URL(request.args.get('redirect_uri')).with_query({'code': auth_code, 'state': request.args.get('state')}).__str__())
|
||||
else:
|
||||
raise OAuth2Error('Internal error: Authorization code generation failed')
|
||||
|
||||
|
||||
@server.route('/oauth/token', methods=['POST'])
|
||||
def issue_token():
|
||||
"""第三步token发放
|
||||
OAuth2令牌端点
|
||||
支持授权码模式(authorization_code)
|
||||
参数:
|
||||
- grant_type: 必须为'authorization_code'
|
||||
- code: 授权码(必须)
|
||||
- redirect_uri: 必须与授权请求时一致
|
||||
- client_id: 客户端ID(必须)
|
||||
- client_secret: 客户端密钥(必须)
|
||||
"""
|
||||
import secrets
|
||||
from database.sql_db.dao.dao_oauth2 import exist_code, validate_client, insert_token
|
||||
from config.dashgo_conf import OAuth2Conf
|
||||
|
||||
# 验证客户端凭证
|
||||
client_id = request.form.get('client_id')
|
||||
client_secret = request.form.get('client_secret')
|
||||
client = validate_client(client_id, client_secret)
|
||||
if not client:
|
||||
raise OAuth2Error('Invalid client credentials', HttpStatusConstant.UNAUTHORIZED)
|
||||
# 验证授权类型
|
||||
if request.form.get('grant_type') != 'authorization_code':
|
||||
raise OAuth2Error('Unsupported grant_type')
|
||||
# 获取授权码
|
||||
code = request.form.get('code')
|
||||
if not code:
|
||||
raise OAuth2Error('Missing authorization code')
|
||||
auth_code = exist_code(code=code, client_id=client_id)
|
||||
if not auth_code:
|
||||
raise OAuth2Error('Invalid authorization code', HttpStatusConstant.UNAUTHORIZED)
|
||||
# 验证授权码
|
||||
if not auth_code.is_valid():
|
||||
raise OAuth2Error('Authorization code expired')
|
||||
if not auth_code.check_client_id(client_id):
|
||||
raise OAuth2Error('Client mismatch')
|
||||
if not auth_code.check_redirect_uri(request.form.get('redirect_uri')):
|
||||
raise OAuth2Error('Redirect URI mismatch')
|
||||
if insert_token(
|
||||
token=(token_ := secrets.token_urlsafe(OAuth2Conf.OAuth2TokenLength)),
|
||||
client_id=client_id,
|
||||
user_name=auth_code.user_name,
|
||||
expires_at=datetime.now() + timedelta(minutes=OAuth2Conf.OAuth2TokenExpiresInMinutes),
|
||||
scope=auth_code.scope,
|
||||
):
|
||||
auth_code.delete()
|
||||
return jsonify(
|
||||
{
|
||||
'token_type': 'bearer',
|
||||
'access_token': token_,
|
||||
'expires_in': OAuth2Conf.OAuth2TokenExpiresInMinutes * 60,
|
||||
'scope': auth_code.scope,
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise OAuth2Error('Internal error: Token generation failed')
|
||||
|
||||
|
||||
@server.route('/api/userinfo')
|
||||
@require_oauth('userinfo')
|
||||
def userinfo():
|
||||
"""受保护端点"""
|
||||
from database.sql_db.dao.dao_user import get_user_info
|
||||
from common.utilities.util_menu_access import MenuAccess
|
||||
|
||||
token = current_token()
|
||||
user = get_user_info(user_names=[token.user_name])[0]
|
||||
access_metas = MenuAccess(token.user_name).all_access_metas
|
||||
return jsonify(
|
||||
{
|
||||
'user_name': user.user_name,
|
||||
'user_full_name': user.user_full_name,
|
||||
'user_sex': user.user_sex,
|
||||
'access_metas': access_metas,
|
||||
'user_email': user.user_email,
|
||||
'phone_number': user.phone_number,
|
||||
'user_remark': user.user_remark,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# oauth2_grant登录后重定向
|
||||
@server.before_request
|
||||
def oauth2_grant_redirect():
|
||||
from common.utilities.util_jwt import jwt_decode_from_session, AccessFailType
|
||||
from yarl import URL
|
||||
|
||||
if not isinstance(jwt_decode_from_session(), AccessFailType) and request.path == '/' and request.args.get('next') is not None:
|
||||
return redirect(URL(request.args.get('next')).extend_query(confirm='yes').__str__())
|
||||
|
||||
|
||||
# OAuth2错误处理器
|
||||
@server.errorhandler(OAuth2Error)
|
||||
def handle_oauth2_error(e):
|
||||
return jsonify({'error': 'invalid_request', 'error_description': e.description}), e.status_code
|
||||
|
||||
|
||||
# 首页重定向
|
||||
@server.before_request
|
||||
def main_page_redirct():
|
||||
if request.path == '/':
|
||||
return redirect('/dashboard_/workbench')
|
||||
|
|
Loading…
Reference in New Issue