Update `credit.py`
This commit is contained in:
parent
76b5259f8d
commit
b3d7804ba1
|
@ -1,12 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi.responses import RedirectResponse
|
||||
import stripe
|
||||
from prisma import Json
|
||||
from prisma.enums import CreditTransactionType
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import CreditTransaction, User
|
||||
import stripe
|
||||
|
||||
from backend.data.block import Block, BlockInput, get_block
|
||||
from backend.data.block_cost_config import BLOCK_COSTS
|
||||
|
@ -20,12 +19,12 @@ stripe.api_key = settings.secrets.stripe_api_key
|
|||
|
||||
class UserCreditBase(ABC):
|
||||
@abstractmethod
|
||||
async def get_or_refill_credit(self, user_id: str) -> int:
|
||||
async def get_credits(self, user_id: str) -> int:
|
||||
"""
|
||||
Get the current credit for the user and refill if no transaction has been made in the current cycle.
|
||||
Get the current credits for the user.
|
||||
|
||||
Returns:
|
||||
int: The current credit for the user.
|
||||
int: The current credits for the user.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -58,7 +57,7 @@ class UserCreditBase(ABC):
|
|||
@abstractmethod
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
"""
|
||||
Top up the credits for the user.
|
||||
Top up the credits for the user immediately.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
|
@ -95,7 +94,7 @@ class UserCredit(UserCreditBase):
|
|||
def __init__(self):
|
||||
self.num_user_credits_refill = settings.config.num_user_credits_refill
|
||||
|
||||
async def get_or_refill_credit(self, user_id: str) -> int:
|
||||
async def get_credits(self, user_id: str) -> int:
|
||||
cur_time = self.time_now()
|
||||
cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
nxt_month = (
|
||||
|
@ -176,8 +175,8 @@ class UserCredit(UserCreditBase):
|
|||
) -> bool:
|
||||
"""
|
||||
Filter rules:
|
||||
- If costFilter is an object, then check if costFilter is the subset of inputValues
|
||||
- Otherwise, check if costFilter is equal to inputValues.
|
||||
- If cost_filter is an object, then check if cost_filter is the subset of input_data
|
||||
- Otherwise, check if cost_filter is equal to input_data.
|
||||
- Undefined, null, and empty string are considered as equal.
|
||||
"""
|
||||
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
|
||||
|
@ -230,10 +229,14 @@ class UserCredit(UserCreditBase):
|
|||
return cost
|
||||
|
||||
async def top_up_credits(self, user_id: str, amount: int):
|
||||
if amount < 0:
|
||||
raise ValueError(f"Top up amount must not be negative: {amount}")
|
||||
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"isActive": True,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
"createdAt": self.time_now(),
|
||||
}
|
||||
|
@ -278,7 +281,7 @@ class UserCredit(UserCreditBase):
|
|||
# Create pending transaction
|
||||
await CreditTransaction.prisma().create(
|
||||
data={
|
||||
"transactionKey": checkout_session.id, # TODO kcze add new model field?
|
||||
"transactionKey": checkout_session.id,
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"type": CreditTransactionType.TOP_UP,
|
||||
|
@ -291,7 +294,6 @@ class UserCredit(UserCreditBase):
|
|||
|
||||
# https://docs.stripe.com/checkout/fulfillment
|
||||
async def fulfill_checkout(self, session_id):
|
||||
print("fulfill_checkout", session_id)
|
||||
# Retrieve CreditTransaction
|
||||
credit_transaction = await CreditTransaction.prisma().find_first_or_raise(
|
||||
where={"transactionKey": session_id}
|
||||
|
@ -307,7 +309,6 @@ class UserCredit(UserCreditBase):
|
|||
# Check the Checkout Session's payment_status property
|
||||
# to determine if fulfillment should be peformed
|
||||
if checkout_session.payment_status != "unpaid":
|
||||
print("Payment status is not unpaid!")
|
||||
# Activate the CreditTransaction
|
||||
await CreditTransaction.prisma().update(
|
||||
where={
|
||||
|
@ -325,7 +326,7 @@ class UserCredit(UserCreditBase):
|
|||
|
||||
|
||||
class DisabledUserCredit(UserCreditBase):
|
||||
async def get_or_refill_credit(self, *args, **kwargs) -> int:
|
||||
async def get_credits(self, *args, **kwargs) -> int:
|
||||
return 0
|
||||
|
||||
async def spend_credits(self, *args, **kwargs) -> int:
|
||||
|
|
|
@ -80,7 +80,7 @@ class DatabaseManager(AppService):
|
|||
user_credit_model = get_user_credit_model()
|
||||
get_or_refill_credit = cast(
|
||||
Callable[[Any, str], int],
|
||||
exposed_run_and_wait(user_credit_model.get_or_refill_credit),
|
||||
exposed_run_and_wait(user_credit_model.get_credits),
|
||||
)
|
||||
spend_credits = cast(
|
||||
Callable[[Any, str, int, str, dict[str, str], float, float], int],
|
||||
|
|
|
@ -15,7 +15,7 @@ user_credit = UserCredit()
|
|||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_block_credit_usage(server: SpinTestServer):
|
||||
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
|
||||
spending_amount_1 = await user_credit.spend_credits(
|
||||
DEFAULT_USER_ID,
|
||||
|
@ -46,17 +46,17 @@ async def test_block_credit_usage(server: SpinTestServer):
|
|||
)
|
||||
assert spending_amount_2 == 0
|
||||
|
||||
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert new_credit == current_credit - spending_amount_1 - spending_amount_2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_block_credit_top_up(server: SpinTestServer):
|
||||
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
|
||||
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
||||
|
||||
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert new_credit == current_credit + 100
|
||||
|
||||
|
||||
|
@ -66,17 +66,17 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||
month2 = datetime(2022, 2, 15)
|
||||
|
||||
user_credit.time_now = lambda: month2
|
||||
month2credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
|
||||
# Month 1 result should only affect month 1
|
||||
user_credit.time_now = lambda: month1
|
||||
month1credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
|
||||
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month1credit + 100
|
||||
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
|
||||
|
||||
# Month 2 balance is unaffected
|
||||
user_credit.time_now = lambda: month2
|
||||
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month2credit
|
||||
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
|
@ -94,5 +94,5 @@ async def test_credit_refill(server: SpinTestServer):
|
|||
)
|
||||
user_credit.time_now = lambda: datetime(2022, 2, 15)
|
||||
|
||||
balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
|
||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||
assert balance == REFILL_VALUE
|
||||
|
|
Loading…
Reference in New Issue