添加OAuth2提供者的支持

This commit is contained in:
luojiaaoo 2025-03-18 23:51:27 +08:00
parent 2f57b5557d
commit 12bb8524d6
11 changed files with 423 additions and 28 deletions

View File

@ -24,6 +24,11 @@ class AuthException(Exception):
self.message = message
self.data = data
class OAuth2Error(Exception):
def __init__(self, description, status_code=400):
self.description = description
self.status_code = status_code
def global_exception_handler(error):
from dash import set_props

View File

@ -0,0 +1,49 @@
from common.exception import OAuth2Error
from flask import request
authorize_html = """
<p>The application <strong>{{client_id}}</strong> is requesting:
<strong>{{ scope }}</strong>
</p>
<p>
from You - a.k.a. <strong>{{ user_name }}</strong>
</p>
<form action="" method="post">
<label>
<input type="checkbox" name="confirm">
<span>Consent?</span>
</label>
<br>
<button>Submit</button>
</form>
"""
def require_oauth(required_scope):
from functools import wraps
def wrapper(func):
@wraps(func)
def decorator(*args, **kwargs):
token = current_token()
if not token:
raise OAuth2Error('Invalid access token', 401)
if not token.is_valid():
raise OAuth2Error('Token expired', 401)
if required_scope not in token.scope.split():
raise OAuth2Error('Invalid scope', 401)
func(*args, **kwargs)
return decorator
return wrapper
def current_token():
from database.sql_db.dao.dao_oauth2 import exist_token
auth_header = request.headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
raise OAuth2Error('Invalid token', 401)
return exist_token(auth_header.split(' ')[-1])

View File

@ -57,4 +57,10 @@ POOL_SIZE = 5
[ProxyConf]
NGINX_PROXY = False
[OAuth2Conf]
OAuth2AuthorizationCodeExpiresInMinutes = 60
OAuth2AuthorizationCodeLength = 32
OAuth2TokenExpiresInMinutes = 120
OAuth2TokenLength = 32
[SqlCacheConf]

View File

@ -7,7 +7,7 @@ class PathProj:
ROOT_PATH = Path(__file__).parent.parent
CONF_FILE_PATH = ROOT_PATH / 'config' / 'dashgo.ini'
AVATAR_DIR_PATH = (ROOT_PATH / '..' / 'user_data' / 'avatars').resolve()
AVATAR_DIR_PATH.mkdir(parents=True,exist_ok=True)
AVATAR_DIR_PATH.mkdir(parents=True, exist_ok=True)
conf = ConfigParser()
@ -64,9 +64,18 @@ class JwtConf(metaclass=BaseMetaConf):
JWT_ALGORITHM: str = 'HS256'
JWT_EXPIRE_MINUTES: int = 1440
class ProxyConf(metaclass=BaseMetaConf):
NGINX_PROXY: bool = False
class OAuth2Conf(metaclass=BaseMetaConf):
OAuth2AuthorizationCodeExpiresInMinutes: int
OAuth2AuthorizationCodeLength: int
OAuth2TokenExpiresInMinutes: int
OAuth2TokenLength: int
class SqlDbConf(metaclass=BaseMetaConf):
RDB_TYPE: str
SQLITE_DB_PATH: str

View File

@ -8,17 +8,24 @@ import dash
from i18n import t__other
# 定义一个客户端回调函数用于处理登录验证代码的显示逻辑总是显示login的路径
# 定义一个客户端回调函数用于处理登录验证代码的显示逻辑总是显示login的路径如果有next的query参数则代表是OAuth2的请求
app.clientside_callback(
"""
(fc_count,timeoutCount) => {
(fc_count,timeoutCount,url) => {
fc_count=fc_count || 0;
const urlObj = new URL(url);
const searchParams = new URLSearchParams(urlObj.search);
if (searchParams.has('next')) {
title = 'OAuth2 Login';
} else {
title = window.dash_clientside.no_update;
}
if (fc_count>="""
+ str(LoginConf.VERIFY_CODE_SHOW_LOGIN_FAIL_COUNT)
+ """) {
return [{'display':'flex'}, {'height': 'max(40%,600px)'}, 1, '/login'];
return [{'display':'flex'}, {'height': 'max(40%,600px)'}, 1, '/login', title];
}
return [{'display':'None'}, {'height': 'max(35%,500px)'}, 0, '/login'];
return [{'display':'None'}, {'height': 'max(35%,500px)'}, 0, '/login', title];
}
""",
[
@ -26,11 +33,13 @@ app.clientside_callback(
Output('login-container', 'style'),
Output('login-store-need-vc', 'data'),
Output('login-location-no-refresh', 'pathname'),
Output('login-title', 'children'),
],
[
Input('login-store-fc', 'data'),
Input('timeout-trigger-verify-code', 'timeoutCount'),
],
State('login-location-no-refresh', 'href'),
prevent_initial_call=True,
)
@ -227,7 +236,7 @@ def login(
if user_login(user_name, password_sha256, is_keep_login_status):
return (
dcc.Location(pathname='/dashboard_/workbench', refresh=True, id='index-redirect'),
dcc.Location(pathname='/', refresh=True, id='index-redirect'),
0, # 重置登录失败次数
dash.no_update,
dash.no_update,
@ -268,7 +277,7 @@ def otp_login(otp_value, user_name):
if totp.verify(int(otp_value)):
jwt_encode_save_access_to_session({'user_name': user_name}, session_permanent=False)
return (
dcc.Location(pathname='/dashboard_/workbench', refresh=True, id='index-redirect'),
dcc.Location(pathname='/', refresh=True, id='index-redirect'),
dash.no_update,
dash.no_update,
)

View File

@ -14,6 +14,7 @@ def render_content():
id='login-container',
children=[
fuc.FefferyDiv(
id='login-title',
children=ShowConf.APP_NAME,
className={
'fontWeight': 'bold',

View File

@ -1,7 +1,6 @@
from config.dashgo_conf import SqlDbConf
from playhouse.pool import PooledMySQLDatabase
from peewee import SqliteDatabase
from server import server
from playhouse.shortcuts import ReconnectMixin
@ -44,6 +43,7 @@ def initialize_database():
if not db_instance.table_exists('sys_user'):
from .entity.table_user import SysUser, SysRoleAccessMeta, SysUserRole, SysGroupUser, SysRole, SysGroupRole, SysGroup
from .entity.table_announcement import SysAnnouncement
from .entity.table_oauth2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
from datetime import datetime
import hashlib
@ -57,6 +57,9 @@ def initialize_database():
SysGroupRole,
SysGroup,
SysAnnouncement,
OAuth2Client,
OAuth2AuthorizationCode,
OAuth2Token
]
)
SysRole.create(
@ -86,14 +89,4 @@ def initialize_database():
SysUserRole.create(user_name='admin', role_name='admin')
# 自动管理数据库上下文
@server.before_request
def _db_connect():
db().connect(reuse_if_open=True)
@server.teardown_request
def _db_close(exc):
_db = db()
if not _db.is_closed():
_db.close()

View File

@ -0,0 +1,67 @@
from peewee import DoesNotExist
from database.sql_db.entity.table_oauth2 import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
from typing import Optional
def exist_client(client_id) -> Optional[OAuth2Client]:
try:
client: OAuth2Client = OAuth2Client.get(OAuth2Client.client_id == client_id)
return client
except DoesNotExist:
return None
def insert_authorization_code(code, client_id, user_name, redirect_uri, expires_at, scope) -> bool:
try:
OAuth2AuthorizationCode.create(
code=code,
client_id=client_id,
user_name=user_name,
redirect_uri=redirect_uri,
expires_at=expires_at,
scope=scope,
)
return True
except Exception as e:
return False
def exist_code(code, client_id) -> Optional[OAuth2AuthorizationCode]:
try:
code: OAuth2AuthorizationCode = OAuth2AuthorizationCode.get((OAuth2AuthorizationCode.code == code) & (OAuth2AuthorizationCode.client_id == client_id))
return code
except DoesNotExist:
return None
def validate_client(client_id, client_secret) -> Optional[OAuth2Client]:
"""验证客户端凭证"""
try:
client: OAuth2Client = OAuth2Client.get(OAuth2Client.client_id == client_id)
if client.client_secret == client_secret:
return client
return None
except DoesNotExist:
return None
def insert_token(token, client_id, user_name, expires_at, scope) -> bool:
try:
OAuth2Token.create(
token=token,
client_id=client_id,
user_name=user_name,
expires_at=expires_at,
scope=scope,
)
return True
except Exception as e:
return False
def exist_token(token) -> Optional[OAuth2Token]:
try:
token: OAuth2Token = OAuth2Token.get(OAuth2Token.token == token)
return token
except DoesNotExist:
return None

View File

@ -770,7 +770,6 @@ def get_dict_group_name_users_roles(user_name) -> Dict[str, Union[str, Set]]:
for group_name in group_names:
group_remark, user_names, group_roles = get_user_and_role_for_group_name(group_name=group_name)
print(group_roles)
user_infos = get_user_info(user_names=user_names)
dict_user_info = {i.user_name: i for i in user_infos}
for user_name_per, user_info in dict_user_info.items():

View File

@ -0,0 +1,77 @@
from peewee import Model, CharField, TextField, DateTimeField
from ..conn import db
from datetime import datetime
import secrets
class BaseModel(Model):
class Meta:
database = db()
class OAuth2Client(BaseModel):
"""注册的三方客户端信息"""
client_id = CharField(max_length=48, help_text='客户端ID')
client_secret = CharField(max_length=120, help_text='客户端密钥')
redirect_uris = TextField(help_text='允许的回调地址')
scope = TextField(help_text='权限范围')
class Meta:
indexes = ((('client_id',), True),)
# grant阶段验证
def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris.split()
def check_scope(self, redirect_uris):
return set(redirect_uris).issubset(set(self.redirect_uris.split()))
# token阶段验证
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
def check_grant_type(self, grant_type):
return grant_type == 'authorization_code'
class OAuth2AuthorizationCode(BaseModel):
"""生成的随机授权码"""
code = CharField(max_length=120, help_text='授权码')
client_id = CharField(max_length=48, help_text='客户端ID')
user_name = CharField(max_length=32, help_text='用户名')
redirect_uri = CharField(max_length=120, help_text='回调地址')
expires_at = DateTimeField(help_text='过期时间')
scope = TextField(help_text='权限范围')
class Meta:
indexes = ((('code',), True),)
def is_valid(self):
return self.expires_at > datetime.now()
def check_redirect_uri(self, redirect_uri):
return redirect_uri == self.redirect_uri
def check_client_id(self, client_id):
return client_id == self.client_id
class OAuth2Token(BaseModel):
"""颁发的访问令牌"""
token = CharField(max_length=48, help_text='访问令牌')
client_id = CharField(max_length=48, help_text='客户端ID')
user_name = CharField(max_length=32, help_text='用户名')
expires_at = DateTimeField(help_text='过期时间')
scope = TextField(help_text='权限范围')
class Meta:
indexes = ((('token',), True),)
def is_valid(self):
from datetime import datetime
"""检查令牌是否有效"""
return self.expires_at > datetime.now()

View File

@ -1,9 +1,12 @@
from flask import request, redirect, send_from_directory, abort
from flask import request, redirect, send_from_directory, abort, jsonify
from common.exception import OAuth2Error
from common.utilities.util_oauth2 import current_token, require_oauth
from config.dashgo_conf import ShowConf, FlaskConf, CommonConf, PathProj
from common.utilities.util_logger import Log
from common.exception import global_exception_handler
from common.utilities.util_dash import CustomDash
from common.constant import HttpStatusConstant
from datetime import datetime, timedelta
from i18n import t__other
@ -33,7 +36,6 @@ server = app.server
def download_file(user_name):
file_name = f'{user_name}.jpg'
if '..' in user_name:
logger.warning(f'有人尝试通过头像文件接口攻击URL:{request.url}IP:{request.remote_addr}')
abort(HttpStatusConstant.FORBIDDEN)
else:
@ -49,18 +51,12 @@ def ban_bypass_proxy():
abort(HttpStatusConstant.FORBIDDEN)
# 首页重定向
@server.before_request
def main_page_redirct():
if request.path == '/':
return redirect('/dashboard_/workbench')
# 恶意访问管理页面拦截器
@server.before_request
def ban_admin():
if request.path.startswith('/admin'):
from common.utilities.util_browser import get_browser_info
browser_info = get_browser_info()
logger.warning(f'有人尝试访问不存在的管理页面URL:{browser_info.url}IP:{browser_info.request_addr}')
abort(HttpStatusConstant.NOT_FOUND)
@ -79,3 +75,187 @@ def get_user_agent_info():
browser_info.request_addr,
t__other('Chrome内核版本号太低请升级浏览器'),
)
# 自动管理数据库上下文
@server.before_request
def _db_connect():
from database.sql_db.conn import db
db().connect(reuse_if_open=True)
@server.teardown_request
def _db_close(exc):
from database.sql_db.conn import db
_db = db()
if not _db.is_closed():
_db.close()
@server.route('/oauth/authorize', methods=['GET', 'POST'])
def authorize():
"""第一步grant用户确认阶段
OAuth2授权端点
参数
- client_id: 客户端ID必须
- redirect_uri: 重定向URI必须
- response_type: 必须为'code'
- scope: 请求的权限范围可选本项目为必选
- state: 客户端状态值可选本项目为必选
"""
from flask import request, render_template_string
from common.utilities.util_jwt import jwt_decode_from_session, AccessFailType
from common.utilities.util_oauth2 import authorize_html
from config.dashgo_conf import OAuth2Conf
from database.sql_db.dao.dao_oauth2 import exist_client, insert_authorization_code
from yarl import URL
import secrets
# 1. 如果没登陆,登录了再来认证
if isinstance((rt_access := jwt_decode_from_session()), AccessFailType):
return redirect(URL.build(path='/login').with_query({'next': request.url}).__str__())
### 参数检查
user_name = rt_access['user_name']
# 检查client_id
if request.args.get('client_id') is None:
raise OAuth2Error('Invalid_client_id')
client = exist_client(client_id=request.args.get('client_id'))
# 检查scope
if request.args.get('scope') is None or client.check_scope(request.args.get('scope').split()):
raise OAuth2Error('Invalid_scope')
# 检查redirect_uri
if request.args.get('redirect_uri') is None or client.check_redirect_uri(request.args.get('redirect_uri')):
raise OAuth2Error('Invalid_redirect_uri')
# 检查response_type
if request.args.get('response_type') != 'code':
raise OAuth2Error('Invalid_response_type')
# 检查state
if request.args.get('state') is None:
raise OAuth2Error('Invalid_state')
# 2. 登录啦,是否同意授权?
if request.method == 'GET' and request.args.get('confirm') != 'yes':
return render_template_string(authorize_html, scope=request.args.get('scope'), client_id=request.args.get('client_id'), user_name=user_name)
# 3. 同意授权
grant_user = rt_access['user_name'] if request.form['confirm'] or request.args.get('confirm') == 'yes' else None
if grant_user is None:
raise OAuth2Error('NOT_IMPLEMENTED', HttpStatusConstant.NOT_IMPLEMENTED)
# 生成code授权码
if insert_authorization_code(
code=(auth_code := secrets.token_urlsafe(OAuth2Conf.OAuth2AuthorizationCodeLength)),
client_id=request.args.get('client_id'),
user_name=user_name,
expires_at=datetime.now() + timedelta(minutes=OAuth2Conf.OAuth2AuthorizationCodeExpiresInMinutes),
redirect_uri=request.args.get('redirect_uri'),
scope=request.args.get('scope'),
):
return redirect(URL(request.args.get('redirect_uri')).with_query({'code': auth_code, 'state': request.args.get('state')}).__str__())
else:
raise OAuth2Error('Internal error: Authorization code generation failed')
@server.route('/oauth/token', methods=['POST'])
def issue_token():
"""第三步token发放
OAuth2令牌端点
支持授权码模式authorization_code
参数
- grant_type: 必须为'authorization_code'
- code: 授权码必须
- redirect_uri: 必须与授权请求时一致
- client_id: 客户端ID必须
- client_secret: 客户端密钥必须
"""
import secrets
from database.sql_db.dao.dao_oauth2 import exist_code, validate_client, insert_token
from config.dashgo_conf import OAuth2Conf
# 验证客户端凭证
client_id = request.form.get('client_id')
client_secret = request.form.get('client_secret')
client = validate_client(client_id, client_secret)
if not client:
raise OAuth2Error('Invalid client credentials', HttpStatusConstant.UNAUTHORIZED)
# 验证授权类型
if request.form.get('grant_type') != 'authorization_code':
raise OAuth2Error('Unsupported grant_type')
# 获取授权码
code = request.form.get('code')
if not code:
raise OAuth2Error('Missing authorization code')
auth_code = exist_code(code=code, client_id=client_id)
if not auth_code:
raise OAuth2Error('Invalid authorization code', HttpStatusConstant.UNAUTHORIZED)
# 验证授权码
if not auth_code.is_valid():
raise OAuth2Error('Authorization code expired')
if not auth_code.check_client_id(client_id):
raise OAuth2Error('Client mismatch')
if not auth_code.check_redirect_uri(request.form.get('redirect_uri')):
raise OAuth2Error('Redirect URI mismatch')
if insert_token(
token=(token_ := secrets.token_urlsafe(OAuth2Conf.OAuth2TokenLength)),
client_id=client_id,
user_name=auth_code.user_name,
expires_at=datetime.now() + timedelta(minutes=OAuth2Conf.OAuth2TokenExpiresInMinutes),
scope=auth_code.scope,
):
auth_code.delete()
return jsonify(
{
'token_type': 'bearer',
'access_token': token_,
'expires_in': OAuth2Conf.OAuth2TokenExpiresInMinutes * 60,
'scope': auth_code.scope,
}
)
else:
raise OAuth2Error('Internal error: Token generation failed')
@server.route('/api/userinfo')
@require_oauth('userinfo')
def userinfo():
"""受保护端点"""
from database.sql_db.dao.dao_user import get_user_info
from common.utilities.util_menu_access import MenuAccess
token = current_token()
user = get_user_info(user_names=[token.user_name])[0]
access_metas = MenuAccess(token.user_name).all_access_metas
return jsonify(
{
'user_name': user.user_name,
'user_full_name': user.user_full_name,
'user_sex': user.user_sex,
'access_metas': access_metas,
'user_email': user.user_email,
'phone_number': user.phone_number,
'user_remark': user.user_remark,
}
)
# oauth2_grant登录后重定向
@server.before_request
def oauth2_grant_redirect():
from common.utilities.util_jwt import jwt_decode_from_session, AccessFailType
from yarl import URL
if not isinstance(jwt_decode_from_session(), AccessFailType) and request.path == '/' and request.args.get('next') is not None:
return redirect(URL(request.args.get('next')).extend_query(confirm='yes').__str__())
# OAuth2错误处理器
@server.errorhandler(OAuth2Error)
def handle_oauth2_error(e):
return jsonify({'error': 'invalid_request', 'error_description': e.description}), e.status_code
# 首页重定向
@server.before_request
def main_page_redirct():
if request.path == '/':
return redirect('/dashboard_/workbench')