Merge branch 'dev' into kpczerwinski/secrt-1012-mvp-implement-top-up-flow

This commit is contained in:
Krzysztof Czerwinski 2024-12-12 16:43:12 +01:00
commit 2481d3d88f
47 changed files with 424 additions and 232 deletions

View File

@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class ImageSize(str, Enum):
@ -101,12 +102,10 @@ class ImageGenModel(str, Enum):
class AIImageGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="Enter your Replicate API key to access the image generation API. You can obtain an API key from https://replicate.com/account/api-tokens.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
@ -54,13 +55,11 @@ class NormalizationStrategy(str, Enum):
class AIMusicGeneratorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="A description of the music you want to generate",

View File

@ -12,6 +12,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@ -140,13 +141,11 @@ logger = logging.getLogger(__name__)
class AIShortformVideoCreatorBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["revid"], Literal["api_key"]] = (
CredentialsField(
provider="revid",
supported_credential_types={"api_key"},
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REVID], Literal["api_key"]
] = CredentialsField(
description="The revid.ai integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script: str = SchemaField(
description="""1. Use short and punctuated sentences\n\n2. Use linebreaks to create a new clip\n\n3. Text outside of brackets is spoken by the AI, and [text between brackets] will be used to guide the visual generation. For example, [close-up of a cat] will show a close-up of a cat.""",

View File

@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@ -39,12 +40,10 @@ class CodeExecutionBlock(Block):
# TODO : Add support to upload and download files
# Currently, You can customized the CPU and Memory, only by creating a pre customized sandbox template
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["e2b"], Literal["api_key"]] = (
CredentialsField(
provider="e2b",
supported_credential_types={"api_key"},
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.E2B], Literal["api_key"]
] = CredentialsField(
description="Enter your api key for the E2B Sandbox. You can get it in here - https://e2b.dev/docs",
)
# Todo : Option to run commond in background

View File

@ -12,16 +12,15 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
DiscordCredentials = CredentialsMetaInput[Literal["discord"], Literal["api_key"]]
DiscordCredentials = CredentialsMetaInput[
Literal[ProviderName.DISCORD], Literal["api_key"]
]
def DiscordCredentialsField() -> DiscordCredentials:
return CredentialsField(
description="Discord bot token",
provider="discord",
supported_credential_types={"api_key"},
)
return CredentialsField(description="Discord bot token")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
ExaCredentials = APIKeyCredentials
ExaCredentialsInput = CredentialsMetaInput[
Literal["exa"],
Literal[ProviderName.EXA],
Literal["api_key"],
]
@ -28,8 +29,4 @@ TEST_CREDENTIALS_INPUT = {
def ExaCredentialsField() -> ExaCredentialsInput:
"""Creates an Exa credentials input on a block."""
return CredentialsField(
provider="exa",
supported_credential_types={"api_key"},
description="The Exa integration requires an API Key.",
)
return CredentialsField(description="The Exa integration requires an API Key.")

View File

@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
Literal["fal"],
Literal[ProviderName.FAL],
Literal["api_key"],
]
@ -30,7 +31,5 @@ def FalCredentialsField() -> FalCredentialsInput:
Creates a FAL credentials input on a block.
"""
return CredentialsField(
provider="fal",
supported_credential_types={"api_key"},
description="The FAL integration can be used with an API Key.",
)

View File

@ -8,6 +8,7 @@ from backend.data.model import (
CredentialsMetaInput,
OAuth2Credentials,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
secrets = Secrets()
@ -17,7 +18,7 @@ GITHUB_OAUTH_IS_CONFIGURED = bool(
GithubCredentials = APIKeyCredentials | OAuth2Credentials
GithubCredentialsInput = CredentialsMetaInput[
Literal["github"],
Literal[ProviderName.GITHUB],
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
]
@ -30,10 +31,6 @@ def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
""" # noqa
return CredentialsField(
provider="github",
supported_credential_types=(
{"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
),
required_scopes={scope},
description="The GitHub integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",

View File

@ -3,6 +3,7 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import CredentialsField, CredentialsMetaInput, OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
# --8<-- [start:GoogleOAuthIsConfigured]
@ -12,7 +13,9 @@ GOOGLE_OAUTH_IS_CONFIGURED = bool(
)
# --8<-- [end:GoogleOAuthIsConfigured]
GoogleCredentials = OAuth2Credentials
GoogleCredentialsInput = CredentialsMetaInput[Literal["google"], Literal["oauth2"]]
GoogleCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.GOOGLE], Literal["oauth2"]
]
def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
@ -23,8 +26,6 @@ def GoogleCredentialsField(scopes: list[str]) -> GoogleCredentialsInput:
scopes: The authorization scopes needed for the block to work.
"""
return CredentialsField(
provider="google",
supported_credential_types={"oauth2"},
required_scopes=set(scopes),
description="The Google integration requires OAuth2 authentication.",
)

View File

@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@ -38,12 +39,8 @@ class Place(BaseModel):
class GoogleMapsSearchBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal["google_maps"], Literal["api_key"]
] = CredentialsField(
provider="google_maps",
supported_credential_types={"api_key"},
description="Google Maps API Key",
)
Literal[ProviderName.GOOGLE_MAPS], Literal["api_key"]
] = CredentialsField(description="Google Maps API Key")
query: str = SchemaField(
description="Search query for local businesses",
placeholder="e.g., 'restaurants in New York'",

View File

@ -3,10 +3,11 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
HubSpotCredentials = APIKeyCredentials
HubSpotCredentialsInput = CredentialsMetaInput[
Literal["hubspot"],
Literal[ProviderName.HUBSPOT],
Literal["api_key"],
]
@ -14,8 +15,6 @@ HubSpotCredentialsInput = CredentialsMetaInput[
def HubSpotCredentialsField() -> HubSpotCredentialsInput:
"""Creates a HubSpot credentials input on a block."""
return CredentialsField(
provider="hubspot",
supported_credential_types={"api_key"},
description="The HubSpot integration requires an API Key.",
)

View File

@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@ -83,13 +84,10 @@ class UpscaleOption(str, Enum):
class IdeogramModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["ideogram"], Literal["api_key"]] = (
CredentialsField(
provider="ideogram",
supported_credential_types={"api_key"},
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.IDEOGRAM], Literal["api_key"]
] = CredentialsField(
description="The Ideogram integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@ -3,27 +3,14 @@ from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
JinaCredentials = APIKeyCredentials
JinaCredentialsInput = CredentialsMetaInput[
Literal["jina"],
Literal[ProviderName.JINA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="jina",
api_key=SecretStr("mock-jina-api-key"),
title="Mock Jina API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
def JinaCredentialsField() -> JinaCredentialsInput:
"""
@ -31,8 +18,6 @@ def JinaCredentialsField() -> JinaCredentialsInput:
"""
return CredentialsField(
provider="jina",
supported_credential_types={"api_key"},
description="The Jina integration can be used with an API Key.",
)

View File

@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple
from pydantic import SecretStr
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
from enum import _EnumMemberT
@ -27,7 +29,13 @@ from backend.util.settings import BehaveAs, Settings
logger = logging.getLogger(__name__)
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama", "open_router"]
LLMProviderName = Literal[
ProviderName.ANTHROPIC,
ProviderName.GROQ,
ProviderName.OLLAMA,
ProviderName.OPENAI,
ProviderName.OPEN_ROUTER,
]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]
TEST_CREDENTIALS = APIKeyCredentials(
@ -48,8 +56,6 @@ TEST_CREDENTIALS_INPUT = {
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider=["anthropic", "groq", "openai", "ollama", "open_router"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel

View File

@ -12,6 +12,7 @@ from backend.data.model import (
SchemaField,
SecretField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@ -77,12 +78,10 @@ class PublishToMediumBlock(Block):
description="Whether to notify followers that the user has published",
placeholder="False",
)
credentials: CredentialsMetaInput[Literal["medium"], Literal["api_key"]] = (
CredentialsField(
provider="medium",
supported_credential_types={"api_key"},
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.MEDIUM], Literal["api_key"]
] = CredentialsField(
description="The Medium integration can be used with any API key with sufficient permissions for the blocks it is used on.",
)
class Output(BlockSchema):

View File

@ -10,22 +10,18 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
PineconeCredentials = APIKeyCredentials
PineconeCredentialsInput = CredentialsMetaInput[
Literal["pinecone"],
Literal[ProviderName.PINECONE],
Literal["api_key"],
]
def PineconeCredentialsField() -> PineconeCredentialsInput:
"""
Creates a Pinecone credentials input on a block.
"""
"""Creates a Pinecone credentials input on a block."""
return CredentialsField(
provider="pinecone",
supported_credential_types={"api_key"},
description="The Pinecone integration can be used with an API Key.",
)

View File

@ -13,6 +13,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
@ -54,13 +55,11 @@ class ImageType(str, Enum):
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["replicate"], Literal["api_key"]] = (
CredentialsField(
provider="replicate",
supported_credential_types={"api_key"},
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
prompt: str = SchemaField(
description="Text prompt for image generation",

View File

@ -11,6 +11,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
class GetWikipediaSummaryBlock(Block, GetRequest):
@ -65,10 +66,8 @@ class GetWeatherInformationBlock(Block, GetRequest):
description="Location to get weather information for"
)
credentials: CredentialsMetaInput[
Literal["openweathermap"], Literal["api_key"]
Literal[ProviderName.OPENWEATHERMAP], Literal["api_key"]
] = CredentialsField(
provider="openweathermap",
supported_credential_types={"api_key"},
description="The OpenWeatherMap integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@ -4,16 +4,15 @@ from typing import Literal
from pydantic import BaseModel, SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
Slant3DCredentialsInput = CredentialsMetaInput[Literal["slant3d"], Literal["api_key"]]
Slant3DCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.SLANT3D], Literal["api_key"]
]
def Slant3DCredentialsField() -> Slant3DCredentialsInput:
return CredentialsField(
provider="slant3d",
supported_credential_types={"api_key"},
description="Slant3D API key for authentication",
)
return CredentialsField(description="Slant3D API key for authentication")
TEST_CREDENTIALS = APIKeyCredentials(

View File

@ -10,6 +10,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@ -29,13 +30,11 @@ TEST_CREDENTIALS_INPUT = {
class CreateTalkingAvatarVideoBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[Literal["d_id"], Literal["api_key"]] = (
CredentialsField(
provider="d_id",
supported_credential_types={"api_key"},
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
credentials: CredentialsMetaInput[
Literal[ProviderName.D_ID], Literal["api_key"]
] = CredentialsField(
description="The D-ID integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
script_input: str = SchemaField(
description="The text input for the script",

View File

@ -9,6 +9,7 @@ from backend.data.model import (
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.request import requests
TEST_CREDENTIALS = APIKeyCredentials(
@ -38,10 +39,8 @@ class UnrealTextToSpeechBlock(Block):
default="Scarlett",
)
credentials: CredentialsMetaInput[
Literal["unreal_speech"], Literal["api_key"]
Literal[ProviderName.UNREAL_SPEECH], Literal["api_key"]
] = CredentialsField(
provider="unreal_speech",
supported_credential_types={"api_key"},
description="The Unreal Speech integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)

View File

@ -65,7 +65,7 @@ class BlockCategory(Enum):
class BlockSchema(BaseModel):
cached_jsonschema: ClassVar[dict[str, Any]] = {}
cached_jsonschema: ClassVar[dict[str, Any]]
@classmethod
def jsonschema(cls) -> dict[str, Any]:
@ -145,6 +145,10 @@ class BlockSchema(BaseModel):
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
"""
super().__pydantic_init_subclass__(**kwargs)
# Reset cached JSON schema to prevent inheriting it from parent class
cls.cached_jsonschema = {}
credentials_fields = [
field_name
for field_name, info in cls.model_fields.items()
@ -176,6 +180,11 @@ class BlockSchema(BaseModel):
f"Field 'credentials' on {cls.__qualname__} "
f"must be of type {CredentialsMetaInput.__name__}"
)
if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
credentials_input_type = cast(
CredentialsMetaInput, credentials_field.annotation
)
credentials_input_type.validate_credentials_field_schema(cls)
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)

View File

@ -2,9 +2,9 @@ from abc import ABC, abstractmethod
from datetime import datetime, timezone
from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.enums import CreditTransactionType
from prisma.errors import UniqueViolationError
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
@ -76,7 +76,7 @@ class UserCredit(UserCreditBase):
else cur_month.replace(year=cur_month.year + 1, month=1)
)
user_credit = await UserBlockCredit.prisma().group_by(
user_credit = await CreditTransaction.prisma().group_by(
by=["userId"],
sum={"amount": True},
where={
@ -93,10 +93,10 @@ class UserCredit(UserCreditBase):
key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"
try:
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"amount": self.num_user_credits_refill,
"type": UserBlockCreditType.TOP_UP,
"type": CreditTransactionType.TOP_UP,
"userId": user_id,
"transactionKey": key,
"createdAt": self.time_now(),
@ -184,11 +184,11 @@ class UserCredit(UserCreditBase):
if validate_balance and user_credit < cost:
raise ValueError(f"Insufficient credit: {user_credit} < {cost}")
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"userId": user_id,
"amount": -cost,
"type": UserBlockCreditType.USAGE,
"type": CreditTransactionType.USAGE,
"blockId": block.id,
"metadata": Json(
{
@ -202,11 +202,11 @@ class UserCredit(UserCreditBase):
return cost
async def top_up_credits(self, user_id: str, amount: int):
await UserBlockCredit.prisma().create(
await CreditTransaction.prisma().create(
data={
"userId": user_id,
"amount": amount,
"type": UserBlockCreditType.TOP_UP,
"type": CreditTransactionType.TOP_UP,
"createdAt": self.time_now(),
}
)

View File

@ -7,6 +7,7 @@ from pydantic import Field
from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from .db import BaseDbModel
@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
class Webhook(BaseDbModel):
user_id: str
provider: str
provider: ProviderName
credentials_id: str
webhook_type: str
resource: str
@ -37,7 +38,7 @@ class Webhook(BaseDbModel):
return Webhook(
id=webhook.id,
user_id=webhook.userId,
provider=webhook.provider,
provider=ProviderName(webhook.provider),
credentials_id=webhook.credentialsId,
webhook_type=webhook.webhookType,
resource=webhook.resource,
@ -61,7 +62,7 @@ async def create_webhook(webhook: Webhook) -> Webhook:
data={
"id": webhook.id,
"userId": webhook.user_id,
"provider": webhook.provider,
"provider": webhook.provider.value,
"credentialsId": webhook.credentials_id,
"webhookType": webhook.webhook_type,
"resource": webhook.resource,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
@ -11,19 +12,32 @@ from typing import (
Optional,
TypedDict,
TypeVar,
get_args,
)
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler, SecretStr, field_serializer
from pydantic import (
BaseModel,
ConfigDict,
Field,
GetCoreSchemaHandler,
SecretStr,
field_serializer,
)
from pydantic_core import (
CoreSchema,
PydanticUndefined,
PydanticUndefinedType,
ValidationError,
core_schema,
)
from backend.integrations.providers import ProviderName
from backend.util.settings import Secrets
if TYPE_CHECKING:
from backend.data.block import BlockSchema
T = TypeVar("T")
logger = logging.getLogger(__name__)
@ -220,7 +234,7 @@ class UserIntegrations(BaseModel):
oauth_states: list[OAuthState] = Field(default_factory=list)
CP = TypeVar("CP", bound=str)
CP = TypeVar("CP", bound=ProviderName)
CT = TypeVar("CT", bound=CredentialsType)
@ -233,19 +247,51 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT
@staticmethod
def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
schema["credentials_provider"] = get_args(
cls.model_fields["provider"].annotation
)
schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)
class CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
model_config = ConfigDict(
json_schema_extra=_add_json_schema_extra, # type: ignore
)
@classmethod
def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
"""Validates the schema of a `credentials` field"""
field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
try:
schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
field_schema
)
except ValidationError as e:
if "Field required [type=missing" not in str(e):
raise
raise TypeError(
"Field 'credentials' JSON schema lacks required extra items: "
f"{field_schema}"
) from e
if (
len(schema_extra.credentials_provider) > 1
and not schema_extra.discriminator
):
raise TypeError("Multi-provider CredentialsField requires discriminator!")
class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):
# TODO: move discrimination mechanism out of CredentialsField (frontend + backend)
credentials_provider: list[CP]
credentials_scopes: Optional[list[str]]
credentials_scopes: Optional[list[str]] = None
credentials_types: list[CT]
discriminator: Optional[str] = None
discriminator_mapping: Optional[dict[str, CP]] = None
def CredentialsField(
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
@ -253,26 +299,26 @@ def CredentialsField(
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
) -> CredentialsMetaInput[CP, CT]:
) -> CredentialsMetaInput:
"""
`CredentialsField` must and can only be used on fields named `credentials`.
This is enforced by the `BlockSchema` base class.
"""
if not isinstance(provider, str) and len(provider) > 1 and not discriminator:
raise TypeError("Multi-provider CredentialsField requires discriminator!")
field_schema_extra = CredentialsFieldSchemaExtra[CP, CT](
credentials_provider=[provider] if isinstance(provider, str) else provider,
credentials_scopes=list(required_scopes) or None, # omit if empty
credentials_types=list(supported_credential_types),
discriminator=discriminator,
discriminator_mapping=discriminator_mapping,
)
field_schema_extra = {
k: v
for k, v in {
"credentials_scopes": list(required_scopes) or None,
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}
return Field(
title=title,
description=description,
json_schema_extra=field_schema_extra.model_dump(exclude_none=True),
json_schema_extra=field_schema_extra, # validated on BlockSchema init
**kwargs,
)

View File

@ -1,6 +1,7 @@
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import TYPE_CHECKING
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from redis.lock import Lock as RedisLock
@ -8,10 +9,13 @@ from redis.lock import Lock as RedisLock
from backend.data import redis
from backend.data.model import Credentials
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
logger = logging.getLogger(__name__)
settings = Settings()
@ -148,7 +152,7 @@ class IntegrationCredentialsManager:
self.store.locks.release_all_locks()
def _get_provider_oauth_handler(provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(provider_name: str) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise KeyError(f"Unknown provider '{provider_name}'")

View File

@ -1,10 +1,15 @@
from .base import BaseOAuthHandler
from typing import TYPE_CHECKING
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseOAuthHandler
# --8<-- [start:HANDLERS_BY_NAMEExample]
HANDLERS_BY_NAME: dict[str, type[BaseOAuthHandler]] = {
HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GitHubOAuthHandler,

View File

@ -4,13 +4,14 @@ from abc import ABC, abstractmethod
from typing import ClassVar
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
class BaseOAuthHandler(ABC):
# --8<-- [start:BaseOAuthHandler1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
DEFAULT_SCOPES: ClassVar[list[str]] = []
# --8<-- [end:BaseOAuthHandler1]
@ -76,6 +77,8 @@ class BaseOAuthHandler(ABC):
"""Handles the default scopes for the provider"""
# If scopes are empty, use the default scopes for the provider
if not scopes:
logger.debug(f"Using default scopes for provider {self.PROVIDER_NAME}")
logger.debug(
f"Using default scopes for provider {self.PROVIDER_NAME.value}"
)
scopes = self.DEFAULT_SCOPES
return scopes

View File

@ -3,6 +3,7 @@ from typing import Optional
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@ -23,7 +24,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
access token *with no refresh token*.
""" # noqa
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@ -9,6 +9,7 @@ from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from .base import BaseOAuthHandler
@ -21,7 +22,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
""" # noqa
PROVIDER_NAME = "google"
PROVIDER_NAME = ProviderName.GOOGLE
EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"
DEFAULT_SCOPES = [
"https://www.googleapis.com/auth/userinfo.email",

View File

@ -2,6 +2,7 @@ from base64 import b64encode
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.util.request import requests
from .base import BaseOAuthHandler
@ -16,7 +17,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
- Notion doesn't use scopes
"""
PROVIDER_NAME = "notion"
PROVIDER_NAME = ProviderName.NOTION
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id

View File

@ -1,7 +1,30 @@
from enum import Enum
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
ANTHROPIC = "anthropic"
DISCORD = "discord"
D_ID = "d_id"
E2B = "e2b"
EXA = "exa"
FAL = "fal"
GITHUB = "github"
GOOGLE = "google"
GOOGLE_MAPS = "google_maps"
GROQ = "groq"
HUBSPOT = "hubspot"
IDEOGRAM = "ideogram"
JINA = "jina"
MEDIUM = "medium"
NOTION = "notion"
OLLAMA = "ollama"
OPENAI = "openai"
OPENWEATHERMAP = "openweathermap"
OPEN_ROUTER = "open_router"
PINECONE = "pinecone"
REPLICATE = "replicate"
REVID = "revid"
SLANT3D = "slant3d"
UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]

View File

@ -4,10 +4,11 @@ from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager
if TYPE_CHECKING:
from ..providers import ProviderName
from .base import BaseWebhooksManager
# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict[str, type["BaseWebhooksManager"]] = {
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
handler.PROVIDER_NAME: handler
for handler in [
GithubWebhooksManager,

View File

@ -9,6 +9,7 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config
@ -20,7 +21,7 @@ WT = TypeVar("WT", bound=StrEnum)
class BaseWebhooksManager(ABC, Generic[WT]):
# --8<-- [start:BaseWebhooksManager1]
PROVIDER_NAME: ClassVar[str]
PROVIDER_NAME: ClassVar[ProviderName]
# --8<-- [end:BaseWebhooksManager1]
WebhookType: WT
@ -143,7 +144,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
secret = secrets.token_hex(32)
provider_name = self.PROVIDER_NAME
ingress_url = (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{id}/ingress"
)
provider_webhook_id, config = await self._register_webhook(

View File

@ -8,6 +8,7 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from .base import BaseWebhooksManager
@ -20,7 +21,7 @@ class GithubWebhookType(StrEnum):
class GithubWebhooksManager(BaseWebhooksManager):
PROVIDER_NAME = "github"
PROVIDER_NAME = ProviderName.GITHUB
WebhookType = GithubWebhookType

View File

@ -95,11 +95,18 @@ async def on_node_activate(
if not block.webhook_config:
return node
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
logger.debug(
f"Activating webhook node #{node.id} with config {block.webhook_config}"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
try:
resource = block.webhook_config.resource_format.format(**node.input_default)
@ -167,7 +174,14 @@ async def on_node_deactivate(
if not block.webhook_config:
return node
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[block.webhook_config.provider]()
provider = block.webhook_config.provider
if provider not in WEBHOOK_MANAGERS_BY_NAME:
raise ValueError(
f"Block #{block.id} has webhook_config for provider {provider} "
"which does not support webhooks"
)
webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
if node.webhook_id:
logger.debug(f"Node #{node.id} has webhook_id {node.webhook_id}")
@ -189,7 +203,7 @@ async def on_node_deactivate(
logger.warning(
f"Cannot deregister webhook #{webhook.id}: credentials "
f"#{webhook.credentials_id} not available "
f"({webhook.provider} webhook ID: {webhook.provider_webhook_id})"
f"({webhook.provider.value} webhook ID: {webhook.provider_webhook_id})"
)
return updated_node

View File

@ -1,11 +1,11 @@
import logging
from typing import ClassVar
import requests
from fastapi import Request
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.base import BaseWebhooksManager
logger = logging.getLogger(__name__)
@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class Slant3DWebhooksManager(BaseWebhooksManager):
"""Manager for Slant3D webhooks"""
PROVIDER_NAME: ClassVar[str] = "slant3d"
PROVIDER_NAME = ProviderName.SLANT3D
BASE_URL = "https://www.slant3dapi.com/api"
async def _register_webhook(

View File

@ -1,5 +1,5 @@
import logging
from typing import Annotated, Literal
from typing import TYPE_CHECKING, Annotated, Literal
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr
@ -20,12 +20,16 @@ from backend.data.model import (
)
from backend.executor.manager import ExecutionManager
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.integrations.oauth import HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME
from backend.util.exceptions import NeedConfirmation
from backend.util.service import get_service_client
from backend.util.settings import Settings
if TYPE_CHECKING:
from backend.integrations.oauth import BaseOAuthHandler
from ..utils import get_user_id
logger = logging.getLogger(__name__)
@ -42,7 +46,9 @@ class LoginResponse(BaseModel):
@router.get("/{provider}/login")
def login(
provider: Annotated[str, Path(title="The provider to initiate an OAuth flow for")],
provider: Annotated[
ProviderName, Path(title="The provider to initiate an OAuth flow for")
],
user_id: Annotated[str, Depends(get_user_id)],
request: Request,
scopes: Annotated[
@ -74,7 +80,9 @@ class CredentialsMetaResponse(BaseModel):
@router.post("/{provider}/callback")
def callback(
provider: Annotated[str, Path(title="The target provider for this OAuth exchange")],
provider: Annotated[
ProviderName, Path(title="The target provider for this OAuth exchange")
],
code: Annotated[str, Body(title="Authorization code acquired by user login")],
state_token: Annotated[str, Body(title="Anti-CSRF nonce")],
user_id: Annotated[str, Depends(get_user_id)],
@ -103,11 +111,12 @@ def callback(
if not set(scopes).issubset(set(credentials.scopes)):
# For now, we'll just log the warning and continue
logger.warning(
f"Granted scopes {credentials.scopes} for {provider}do not include all requested scopes {scopes}"
f"Granted scopes {credentials.scopes} for provider {provider.value} "
f"do not include all requested scopes {scopes}"
)
except Exception as e:
logger.error(f"Code->Token exchange failed for provider {provider}: {e}")
logger.error(f"Code->Token exchange failed for provider {provider.value}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to exchange code for tokens: {str(e)}"
)
@ -116,7 +125,8 @@ def callback(
creds_manager.create(user_id, credentials)
logger.debug(
f"Successfully processed OAuth callback for user {user_id} and provider {provider}"
f"Successfully processed OAuth callback for user {user_id} "
f"and provider {provider.value}"
)
return CredentialsMetaResponse(
id=credentials.id,
@ -148,7 +158,9 @@ def list_credentials(
@router.get("/{provider}/credentials")
def list_credentials_by_provider(
provider: Annotated[str, Path(title="The provider to list credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to list credentials for")
],
user_id: Annotated[str, Depends(get_user_id)],
) -> list[CredentialsMetaResponse]:
credentials = creds_manager.store.get_creds_by_provider(user_id, provider)
@ -167,7 +179,9 @@ def list_credentials_by_provider(
@router.get("/{provider}/credentials/{cred_id}")
def get_credential(
provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to retrieve credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
user_id: Annotated[str, Depends(get_user_id)],
) -> Credentials:
@ -184,7 +198,9 @@ def get_credential(
@router.post("/{provider}/credentials", status_code=201)
def create_api_key_credentials(
user_id: Annotated[str, Depends(get_user_id)],
provider: Annotated[str, Path(title="The provider to create credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to create credentials for")
],
api_key: Annotated[str, Body(title="The API key to store")],
title: Annotated[str, Body(title="Optional title for the credentials")],
expires_at: Annotated[
@ -225,7 +241,9 @@ class CredentialsDeletionNeedsConfirmationResponse(BaseModel):
@router.delete("/{provider}/credentials/{cred_id}")
async def delete_credentials(
request: Request,
provider: Annotated[str, Path(title="The provider to delete credentials for")],
provider: Annotated[
ProviderName, Path(title="The provider to delete credentials for")
],
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
user_id: Annotated[str, Depends(get_user_id)],
force: Annotated[
@ -264,15 +282,20 @@ async def delete_credentials(
@router.post("/{provider}/webhooks/{webhook_id}/ingress")
async def webhook_ingress_generic(
request: Request,
provider: Annotated[str, Path(title="Provider where the webhook was registered")],
provider: Annotated[
ProviderName, Path(title="Provider where the webhook was registered")
],
webhook_id: Annotated[str, Path(title="Our ID for the webhook")],
):
logger.debug(f"Received {provider} webhook ingress for ID {webhook_id}")
logger.debug(f"Received {provider.value} webhook ingress for ID {webhook_id}")
webhook_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()
webhook = await get_webhook(webhook_id)
logger.debug(f"Webhook #{webhook_id}: {webhook}")
payload, event_type = await webhook_manager.validate_payload(webhook, request)
logger.debug(f"Validated {provider} {event_type} event with payload {payload}")
logger.debug(
f"Validated {provider.value} {webhook.webhook_type} {event_type} event "
f"with payload {payload}"
)
webhook_event = WebhookEvent(
provider=provider,
@ -341,6 +364,14 @@ async def remove_all_webhooks_for_credentials(
NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
"""
webhooks = await get_all_webhooks(credentials.id)
if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
if webhooks:
logger.error(
f"Credentials #{credentials.id} for provider {credentials.provider} "
f"are attached to {len(webhooks)} webhooks, "
f"but there is no available WebhooksHandler for {credentials.provider}"
)
return
if any(w.attached_nodes for w in webhooks) and not force:
raise NeedConfirmation(
"Some webhooks linked to these credentials are still in use by an agent"
@ -359,18 +390,23 @@ async def remove_all_webhooks_for_credentials(
logger.warning(f"Webhook #{webhook.id} failed to prune")
def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHandler:
def _get_provider_oauth_handler(
req: Request, provider_name: ProviderName
) -> "BaseOAuthHandler":
if provider_name not in HANDLERS_BY_NAME:
raise HTTPException(
status_code=404, detail=f"Unknown provider '{provider_name}'"
status_code=404,
detail=f"Provider '{provider_name.value}' does not support OAuth",
)
client_id = getattr(settings.secrets, f"{provider_name}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name}_client_secret")
client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
client_secret = getattr(settings.secrets, f"{provider_name.value}_client_secret")
if not (client_id and client_secret):
raise HTTPException(
status_code=501,
detail=f"Integration with provider '{provider_name}' is not configured",
detail=(
f"Integration with provider '{provider_name.value}' is not configured"
),
)
handler_class = HANDLERS_BY_NAME[provider_name]

View File

@ -257,6 +257,13 @@ async def do_create_graph(
async def delete_graph(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> DeleteGraphResponse:
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
await on_graph_deactivate(active_version, get_credentials)
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}

View File

@ -0,0 +1,8 @@
-- AlterTable
ALTER TABLE "User" ADD COLUMN "stripeCustomerId" TEXT;
-- AlterEnum
ALTER TYPE "UserBlockCreditType" RENAME TO "CreditTransactionType";
-- AlterTable
ALTER TABLE "UserBlockCredit" RENAME TO "CreditTransaction";

View File

@ -12,13 +12,14 @@ generator client {
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
metadata Json @default("{}")
integrations String @default("")
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
metadata Json @default("{}")
integrations String @default("")
stripeCustomerId String?
// Relations
AgentGraphs AgentGraph[]
@ -26,7 +27,7 @@ model User {
IntegrationWebhooks IntegrationWebhook[]
AnalyticsDetails AnalyticsDetails[]
AnalyticsMetrics AnalyticsMetrics[]
UserBlockCredit UserBlockCredit[]
CreditTransaction CreditTransaction[]
APIKeys APIKey[]
@@index([id])
@ -123,7 +124,7 @@ model AgentBlock {
// Prisma requires explicit back-references.
ReferencedByAgentNode AgentNode[]
UserBlockCredit UserBlockCredit[]
CreditTransaction CreditTransaction[]
}
// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
@ -275,12 +276,12 @@ model AnalyticsMetrics {
@@index([userId])
}
enum UserBlockCreditType {
enum CreditTransactionType {
TOP_UP
USAGE
}
model UserBlockCredit {
model CreditTransaction {
transactionKey String @default(uuid())
createdAt DateTime @default(now())
@ -291,7 +292,7 @@ model UserBlockCredit {
block AgentBlock? @relation(fields: [blockId], references: [id])
amount Int
type UserBlockCreditType
type CreditTransactionType
isActive Boolean @default(true)
metadata Json?

View File

@ -1,7 +1,7 @@
from datetime import datetime
import pytest
from prisma.models import UserBlockCredit
from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.credit import UserCredit
@ -82,7 +82,7 @@ async def test_block_credit_reset(server: SpinTestServer):
@pytest.mark.asyncio(scope="session")
async def test_credit_refill(server: SpinTestServer):
# Clear all transactions within the month
await UserBlockCredit.prisma().update_many(
await CreditTransaction.prisma().update_many(
where={
"userId": DEFAULT_USER_ID,
"createdAt": {

View File

@ -673,6 +673,7 @@ const FlowEditor: React.FC<{
blocks={availableNodes}
addBlock={addNode}
flows={availableFlows}
nodes={nodes}
/>
}
botChildren={

View File

@ -1,10 +1,11 @@
import React, { useState, useCallback } from "react";
import React, { useState, useMemo } from "react";
import { Card, CardContent, CardHeader } from "@/components/ui/card";
import { Label } from "@/components/ui/label";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { TextRenderer } from "@/components/ui/render";
import { ScrollArea } from "@/components/ui/scroll-area";
import { CustomNode } from "@/components/CustomNode";
import { beautifyString } from "@/lib/utils";
import {
Popover,
@ -31,6 +32,7 @@ interface BlocksControlProps {
) => void;
pinBlocksPopover: boolean;
flows: GraphMeta[];
nodes: CustomNode[];
}
/**
@ -47,15 +49,23 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
addBlock,
pinBlocksPopover,
flows,
nodes,
}) => {
const [searchQuery, setSearchQuery] = useState("");
const [selectedCategory, setSelectedCategory] = useState<string | null>(null);
const getFilteredBlockList = (): Block[] => {
const graphHasWebhookNodes = nodes.some(
(n) => n.data.uiType == BlockUIType.WEBHOOK,
);
const graphHasInputNodes = nodes.some(
(n) => n.data.uiType == BlockUIType.INPUT,
);
const filteredAvailableBlocks = useMemo(() => {
const blockList = blocks
.filter((b) => b.uiType !== BlockUIType.AGENT)
.sort((a, b) => a.name.localeCompare(b.name));
const agentList = flows.map(
const agentBlockList = flows.map(
(flow) =>
({
id: SpecialBlockID.AGENT,
@ -80,7 +90,7 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
);
return blockList
.concat(agentList)
.concat(agentBlockList)
.filter(
(block: Block) =>
(block.name.toLowerCase().includes(searchQuery.toLowerCase()) ||
@ -92,8 +102,29 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
.includes(searchQuery.toLowerCase())) &&
(!selectedCategory ||
block.categories.some((cat) => cat.category === selectedCategory)),
);
};
)
.map((block) => ({
...block,
notAvailable:
(block.uiType == BlockUIType.WEBHOOK &&
graphHasWebhookNodes &&
"Agents can only have one webhook-triggered block") ||
(block.uiType == BlockUIType.WEBHOOK &&
graphHasInputNodes &&
"Webhook-triggered blocks can't be used together with input blocks") ||
(block.uiType == BlockUIType.INPUT &&
graphHasWebhookNodes &&
"Input blocks can't be used together with a webhook-triggered block") ||
null,
}));
}, [
blocks,
flows,
searchQuery,
selectedCategory,
graphHasInputNodes,
graphHasWebhookNodes,
]);
const resetFilters = React.useCallback(() => {
setSearchQuery("");
@ -190,14 +221,20 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
className="h-[60vh]"
data-id="blocks-control-scroll-area"
>
{getFilteredBlockList().map((block) => (
{filteredAvailableBlocks.map((block) => (
<Card
key={block.uiKey || block.id}
className="m-2 my-4 flex h-20 cursor-pointer shadow-none hover:shadow-lg"
className={`m-2 my-4 flex h-20 shadow-none ${
block.notAvailable
? "cursor-not-allowed opacity-50"
: "cursor-pointer hover:shadow-lg"
}`}
data-id={`block-card-${block.id}`}
onClick={() =>
!block.notAvailable &&
addBlock(block.id, block.name, block?.hardcodedValues || {})
}
title={block.notAvailable ?? undefined}
>
<div
className={`-ml-px h-full w-3 rounded-l-xl ${getPrimaryCategoryColor(block.categories)}`}

View File

@ -308,6 +308,21 @@ export const NodeGenericInputField: FC<{
handleInputClick={handleInputClick}
/>
);
} else if (
(types.includes("integer") || types.includes("number")) &&
types.includes("null")
) {
return (
<NodeNumberInput
selfKey={propKey}
schema={{ ...propSchema, type: "integer" } as BlockIONumberSubSchema}
value={currentValue}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={handleInputChange}
/>
);
}
}
@ -541,7 +556,7 @@ const NodeKeyValueInput: FC<{
>
<div>
{keyValuePairs.map(({ key, value }, index) => (
/*
/*
The `index` is used as a DOM key instead of the actual `key`
because the `key` can change with each input, causing the input to lose focus.
*/

View File

@ -172,7 +172,7 @@ export const startTutorial = (
text: "Please click the block button to open the blocks menu.",
attachTo: {
element: '[data-id="blocks-control-popover-trigger"]',
on: "bottom",
on: "right",
},
advanceOn: {
selector: '[data-id="blocks-control-popover-trigger"]',
@ -210,7 +210,7 @@ export const startTutorial = (
id: "focus-new-block",
title: "New Block",
text: "This is the Calculator Block! Let's go over how it works.",
attachTo: { element: `[data-id="custom-node-1"]`, on: "top" },
attachTo: { element: `[data-id="custom-node-1"]`, on: "left" },
beforeShowPromise: () => waitForElement('[data-id="custom-node-1"]'),
buttons: [
{
@ -308,7 +308,7 @@ export const startTutorial = (
text: "Enter a name for your agent, add an optional description, and then click 'Save agent' to save your flow.",
attachTo: {
element: '[data-id="save-control-popover-content"]',
on: "bottom",
on: "top",
},
buttons: [],
beforeShowPromise: () =>
@ -371,13 +371,14 @@ export const startTutorial = (
id: "check-output",
title: "Check the Output",
text: "Check here to see the output of the block after running the flow.",
attachTo: { element: '[data-id="latest-output"]', on: "bottom" },
beforeShowPromise: () => waitForElement('[data-id="latest-output"]'),
attachTo: { element: '[data-id="latest-output"]', on: "top" },
beforeShowPromise: () =>
new Promise((resolve) => {
setTimeout(() => {
waitForElement('[data-id="latest-output"]').then(resolve);
}, 100);
}),
buttons: [
{
text: "Back",
action: tour.back,
},
{
text: "Next",
action: tour.next,

View File

@ -121,6 +121,7 @@ from backend.data.model import (
from backend.data.block import Block, BlockOutput, BlockSchema
from backend.data.model import CredentialsField
from backend.integrations.providers import ProviderName
# API Key auth:
@ -128,9 +129,9 @@ class BlockWithAPIKeyAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['api_key']] = CredentialsField(
provider="github",
supported_credential_types={"api_key"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["api_key"]
] = CredentialsField(
description="The GitHub integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
@ -151,9 +152,9 @@ class BlockWithOAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['oauth2']] = CredentialsField(
provider="github",
supported_credential_types={"oauth2"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["oauth2"]
] = CredentialsField(
required_scopes={"repo"},
description="The GitHub integration can be used with OAuth.",
)
@ -174,9 +175,9 @@ class BlockWithAPIKeyAndOAuth(Block):
class Input(BlockSchema):
# Note that the type hint below is require or you will get a type error.
# The first argument is the provider name, the second is the credential type.
credentials: CredentialsMetaInput[Literal['github'], Literal['api_key', 'oauth2']] = CredentialsField(
provider="github",
supported_credential_types={"api_key", "oauth2"},
credentials: CredentialsMetaInput[
Literal[ProviderName.GITHUB], Literal["api_key", "oauth2"]
] = CredentialsField(
required_scopes={"repo"},
description="The GitHub integration can be used with OAuth, "
"or any API key with sufficient permissions for the blocks it is used on.",
@ -227,6 +228,16 @@ response = requests.post(
)
```
The `ProviderName` enum is the single source of truth for which providers exist in our system.
Naturally, to add an authenticated block for a new provider, you'll have to add it here too.
<details>
<summary><code>ProviderName</code> definition</summary>
```python title="backend/integrations/providers.py"
--8<-- "autogpt_platform/backend/backend/integrations/providers.py:ProviderName"
```
</details>
#### Adding an OAuth2 service integration
To add support for a new OAuth2-authenticated service, you'll need to add an `OAuthHandler`.