Update `credit.py`

This commit is contained in:
Krzysztof Czerwinski 2024-12-31 13:54:43 +01:00
parent 76b5259f8d
commit b3d7804ba1
3 changed files with 24 additions and 23 deletions

View File

@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from fastapi.responses import RedirectResponse
import stripe
from prisma import Json
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import CreditTransaction, User
import stripe
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
@ -20,12 +19,12 @@ stripe.api_key = settings.secrets.stripe_api_key
class UserCreditBase(ABC):
@abstractmethod
async def get_or_refill_credit(self, user_id: str) -> int:
async def get_credits(self, user_id: str) -> int:
"""
Get the current credit for the user and refill if no transaction has been made in the current cycle.
Get the current credits for the user.
Returns:
int: The current credit for the user.
int: The current credits for the user.
"""
pass
@ -58,7 +57,7 @@ class UserCreditBase(ABC):
@abstractmethod
async def top_up_credits(self, user_id: str, amount: int):
"""
Top up the credits for the user.
Top up the credits for the user immediately.
Args:
user_id (str): The user ID.
@ -95,7 +94,7 @@ class UserCredit(UserCreditBase):
def __init__(self):
self.num_user_credits_refill = settings.config.num_user_credits_refill
async def get_or_refill_credit(self, user_id: str) -> int:
async def get_credits(self, user_id: str) -> int:
cur_time = self.time_now()
cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
nxt_month = (
@ -176,8 +175,8 @@ class UserCredit(UserCreditBase):
) -> bool:
"""
Filter rules:
- If costFilter is an object, then check if costFilter is the subset of inputValues
- Otherwise, check if costFilter is equal to inputValues.
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
@ -230,10 +229,14 @@ class UserCredit(UserCreditBase):
return cost
async def top_up_credits(self, user_id: str, amount: int):
if amount < 0:
raise ValueError(f"Top up amount must not be negative: {amount}")
await CreditTransaction.prisma().create(
data={
"userId": user_id,
"amount": amount,
"isActive": True,
"type": CreditTransactionType.TOP_UP,
"createdAt": self.time_now(),
}
@ -278,7 +281,7 @@ class UserCredit(UserCreditBase):
# Create pending transaction
await CreditTransaction.prisma().create(
data={
"transactionKey": checkout_session.id, # TODO kcze add new model field?
"transactionKey": checkout_session.id,
"userId": user_id,
"amount": amount,
"type": CreditTransactionType.TOP_UP,
@ -291,7 +294,6 @@ class UserCredit(UserCreditBase):
# https://docs.stripe.com/checkout/fulfillment
async def fulfill_checkout(self, session_id):
print("fulfill_checkout", session_id)
# Retrieve CreditTransaction
credit_transaction = await CreditTransaction.prisma().find_first_or_raise(
where={"transactionKey": session_id}
@ -307,7 +309,6 @@ class UserCredit(UserCreditBase):
# Check the Checkout Session's payment_status property
# to determine if fulfillment should be peformed
if checkout_session.payment_status != "unpaid":
print("Payment status is not unpaid!")
# Activate the CreditTransaction
await CreditTransaction.prisma().update(
where={
@ -325,7 +326,7 @@ class UserCredit(UserCreditBase):
class DisabledUserCredit(UserCreditBase):
async def get_or_refill_credit(self, *args, **kwargs) -> int:
async def get_credits(self, *args, **kwargs) -> int:
return 0
async def spend_credits(self, *args, **kwargs) -> int:

View File

@ -80,7 +80,7 @@ class DatabaseManager(AppService):
user_credit_model = get_user_credit_model()
get_or_refill_credit = cast(
Callable[[Any, str], int],
exposed_run_and_wait(user_credit_model.get_or_refill_credit),
exposed_run_and_wait(user_credit_model.get_credits),
)
spend_credits = cast(
Callable[[Any, str, int, str, dict[str, str], float, float], int],

View File

@ -15,7 +15,7 @@ user_credit = UserCredit()
@pytest.mark.asyncio(scope="session")
async def test_block_credit_usage(server: SpinTestServer):
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
spending_amount_1 = await user_credit.spend_credits(
DEFAULT_USER_ID,
@ -46,17 +46,17 @@ async def test_block_credit_usage(server: SpinTestServer):
)
assert spending_amount_2 == 0
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
assert new_credit == current_credit - spending_amount_1 - spending_amount_2
@pytest.mark.asyncio(scope="session")
async def test_block_credit_top_up(server: SpinTestServer):
current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
new_credit = await user_credit.get_credits(DEFAULT_USER_ID)
assert new_credit == current_credit + 100
@ -66,17 +66,17 @@ async def test_block_credit_reset(server: SpinTestServer):
month2 = datetime(2022, 2, 15)
user_credit.time_now = lambda: month2
month2credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
month2credit = await user_credit.get_credits(DEFAULT_USER_ID)
# Month 1 result should only affect month 1
user_credit.time_now = lambda: month1
month1credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
month1credit = await user_credit.get_credits(DEFAULT_USER_ID)
await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month1credit + 100
assert await user_credit.get_credits(DEFAULT_USER_ID) == month1credit + 100
# Month 2 balance is unaffected
user_credit.time_now = lambda: month2
assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month2credit
assert await user_credit.get_credits(DEFAULT_USER_ID) == month2credit
@pytest.mark.asyncio(scope="session")
@ -94,5 +94,5 @@ async def test_credit_refill(server: SpinTestServer):
)
user_credit.time_now = lambda: datetime(2022, 2, 15)
balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
balance = await user_credit.get_credits(DEFAULT_USER_ID)
assert balance == REFILL_VALUE