diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index a9c4300b9a..2ee9c7c468 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -6,9 +6,13 @@ from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api -from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError -from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError -from controllers.console.wraps import setup_required +from controllers.console.auth.error import (EmailCodeError, InvalidEmailError, + InvalidTokenError, + PasswordMismatchError) +from controllers.console.error import (AccountInFreezeError, AccountNotFound, + EmailSendIpLimitError) +from controllers.console.wraps import (email_password_login_enabled, + setup_required) from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip @@ -22,6 +26,7 @@ from services.feature_service import FeatureService class ForgotPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") @@ -53,6 +58,7 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=str, required=True, location="json") @@ -72,11 +78,20 @@ class ForgotPasswordCheckApi(Resource): if args["code"] != token_data.get("code"): raise EmailCodeError() - return {"is_valid": True, "email": token_data.get("email")} + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} class ForgotPasswordResetApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, nullable=False, location="json") @@ -95,6 +110,9 @@ class ForgotPasswordResetApi(Resource): if reset_data is None: raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() AccountService.revoke_reset_password_token(token) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 41362e9fa2..16c1dcc441 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,7 @@ from controllers.console.error import ( EmailSendIpLimitError, NotAllowedCreateWorkspace, ) -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password @@ -38,6 +38,7 @@ class LoginApi(Resource): """Resource for user login.""" @setup_required + @email_password_login_enabled def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -110,6 +111,7 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 510ea242a9..11b7140b1b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -11,7 +11,8 @@ from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService -from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout +from .error import (NotInitValidateError, NotSetupError, + UnauthorizedAndForceLogout) def account_initialization_required(view): @@ -165,3 +166,16 @@ def enterprise_license_required(view): return view(*args, **kwargs) return decorated + + +def email_password_login_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_email_password_login: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/services/account_service.py b/api/services/account_service.py index dd1cc5f94f..2fbc446984 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -406,10 +406,8 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data={"code": code} - ) + code, token = cls.generate_reset_password_token(account_email, account) + send_reset_password_mail_task.delay( language=language, to=account_email, @@ -418,6 +416,22 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def generate_reset_password_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password")