Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into kpczerwinski/secrt-1012-mvp-implement-top-up-flow

This commit is contained in:
Zamil Majdy 2025-01-10 18:03:36 -06:00
commit c7e994cb69
113 changed files with 15603 additions and 1747 deletions

View File

@ -0,0 +1,76 @@
from typing import Annotated, Any, Literal, Optional, TypedDict
from uuid import uuid4
from pydantic import BaseModel, Field, SecretStr, field_serializer
class _BaseCredentials(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
provider: str
title: Optional[str]
@field_serializer("*")
def dump_secret_strings(value: Any, _info):
if isinstance(value, SecretStr):
return value.get_secret_value()
return value
class OAuth2Credentials(_BaseCredentials):
type: Literal["oauth2"] = "oauth2"
username: Optional[str]
"""Username of the third-party service user that these credentials belong to"""
access_token: SecretStr
access_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the access token expires (if at all)"""
refresh_token: Optional[SecretStr]
refresh_token_expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
scopes: list[str]
metadata: dict[str, Any] = Field(default_factory=dict)
def bearer(self) -> str:
return f"Bearer {self.access_token.get_secret_value()}"
class APIKeyCredentials(_BaseCredentials):
type: Literal["api_key"] = "api_key"
api_key: SecretStr
expires_at: Optional[int]
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
def bearer(self) -> str:
return f"Bearer {self.api_key.get_secret_value()}"
Credentials = Annotated[
OAuth2Credentials | APIKeyCredentials,
Field(discriminator="type"),
]
CredentialsType = Literal["api_key", "oauth2"]
class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
code_verifier: Optional[str] = None
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
class UserMetadata(BaseModel):
integration_credentials: list[Credentials] = Field(default_factory=list)
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
integration_oauth_states: list[dict]
class UserIntegrations(BaseModel):
credentials: list[Credentials] = Field(default_factory=list)
oauth_states: list[OAuthState] = Field(default_factory=list)

View File

@ -61,6 +61,21 @@ GITHUB_CLIENT_SECRET=
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
# Twitter (X) OAuth 2.0 with PKCE Configuration
# 1. Create a Twitter Developer Account:
# - Visit https://developer.x.com/en and sign up
# 2. Set up your application:
# - Navigate to Developer Portal > Projects > Create Project
# - Add a new app to your project
# 3. Configure app settings:
# - App Permissions: Read + Write + Direct Messages
# - App Type: Web App, Automated App or Bot
# - OAuth 2.0 Callback URL: http://localhost:3000/auth/integrations/oauth_callback
# - Save your Client ID and Client Secret below
TWITTER_CLIENT_ID=
TWITTER_CLIENT_SECRET=
## ===== OPTIONAL API KEYS ===== ##
# LLM
@ -109,6 +124,18 @@ REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=
# Fal
FAL_API_KEY=
# Exa
EXA_API_KEY=
# E2B
E2B_API_KEY=
# Nvidia
NVIDIA_API_KEY=
# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false

View File

@ -241,7 +241,7 @@ class AgentOutputBlock(Block):
advanced=True,
)
format: str = SchemaField(
description="The format string to be used to format the recorded_value.",
description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
default="",
advanced=True,
)

View File

@ -56,15 +56,24 @@ class SendWebRequestBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.body, str):
input_data.body = json.loads(input_data.body)
body = input_data.body
if input_data.json_format:
if isinstance(body, str):
try:
# Try to parse as JSON first
body = json.loads(body)
except json.JSONDecodeError:
# If it's not valid JSON and just plain text,
# we should send it as plain text instead
input_data.json_format = False
response = requests.request(
input_data.method.value,
input_data.url,
headers=input_data.headers,
json=input_data.body if input_data.json_format else None,
data=input_data.body if not input_data.json_format else None,
json=body if input_data.json_format else None,
data=body if not input_data.json_format else None,
)
result = response.json() if input_data.json_format else response.text

View File

@ -26,8 +26,10 @@ from backend.data.model import (
)
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter
logger = logging.getLogger(__name__)
fmt = TextFormatter()
LLMProviderName = Literal[
ProviderName.ANTHROPIC,
@ -109,6 +111,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
LLAMA3_1_70B = "llama-3.1-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
# Ollama models
OLLAMA_LLAMA3_2 = "llama3.2"
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
OLLAMA_DOLPHIN = "dolphin-mistral:latest"
@ -163,6 +166,7 @@ MODEL_METADATA = {
# Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
@ -234,7 +238,9 @@ class AIStructuredResponseGeneratorBlock(Block):
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
advanced=False,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
max_tokens: int | None = SchemaField(
advanced=True,
@ -448,8 +454,8 @@ class AIStructuredResponseGeneratorBlock(Block):
values = input_data.prompt_values
if values:
input_data.prompt = input_data.prompt.format(**values)
input_data.sys_prompt = input_data.sys_prompt.format(**values)
input_data.prompt = fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
if input_data.sys_prompt:
prompt.append({"role": "system", "content": input_data.sys_prompt})
@ -576,7 +582,9 @@ class AITextGeneratorBlock(Block):
description="Number of times to retry the LLM call if the response does not match the expected format.",
)
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
advanced=False,
default={},
description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
)
ollama_host: str = SchemaField(
advanced=True,

View File

@ -0,0 +1,32 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
NvidiaCredentials = APIKeyCredentials
NvidiaCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.NVIDIA],
Literal["api_key"],
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="nvidia",
api_key=SecretStr("mock-nvidia-api-key"),
title="Mock Nvidia API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
def NvidiaCredentialsField() -> NvidiaCredentialsInput:
"""Creates an Nvidia credentials input on a block."""
return CredentialsField(description="The Nvidia integration requires an API Key.")

View File

@ -0,0 +1,90 @@
from backend.blocks.nvidia._auth import (
NvidiaCredentials,
NvidiaCredentialsField,
NvidiaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests
class NvidiaDeepfakeDetectBlock(Block):
class Input(BlockSchema):
credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
image_base64: str = SchemaField(
description="Image to analyze for deepfakes", image_upload=True
)
return_image: bool = SchemaField(
description="Whether to return the processed image with markings",
default=False,
)
class Output(BlockSchema):
status: str = SchemaField(
description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
default="",
)
image: str = SchemaField(
description="Processed image with detection markings (if return_image=True)",
default="",
image_output=True,
)
is_deepfake: float = SchemaField(
description="Probability that the image is a deepfake (0-1)",
default=0.0,
)
def __init__(self):
super().__init__(
id="8c7d0d67-e79c-44f6-92a1-c2600c8aac7f",
description="Detects potential deepfakes in images using Nvidia's AI API",
categories={BlockCategory.SAFETY},
input_schema=NvidiaDeepfakeDetectBlock.Input,
output_schema=NvidiaDeepfakeDetectBlock.Output,
)
def run(
self, input_data: Input, *, credentials: NvidiaCredentials, **kwargs
) -> BlockOutput:
url = "https://ai.api.nvidia.com/v1/cv/hive/deepfake-image-detection"
headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
}
image_data = f"data:image/jpeg;base64,{input_data.image_base64}"
payload = {
"input": [image_data],
"return_image": input_data.return_image,
}
try:
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
result = data.get("data", [{}])[0]
# Get deepfake probability from first bounding box if any
deepfake_prob = 0.0
if result.get("bounding_boxes"):
deepfake_prob = result["bounding_boxes"][0].get("is_deepfake", 0.0)
yield "status", result.get("status", "ERROR")
yield "is_deepfake", deepfake_prob
if input_data.return_image:
image_data = result.get("image", "")
output_data = f"data:image/jpeg;base64,{image_data}"
yield "image", output_data
else:
yield "image", ""
except Exception as e:
yield "error", str(e)
yield "status", "ERROR"
yield "is_deepfake", 0.0
yield "image", ""

View File

@ -141,10 +141,10 @@ class ExtractTextInformationBlock(Block):
class FillTextTemplateBlock(Block):
class Input(BlockSchema):
values: dict[str, Any] = SchemaField(
description="Values (dict) to be used in format"
description="Values (dict) to be used in format. These values can be used by putting them in double curly braces in the format template. e.g. {{value_name}}.",
)
format: str = SchemaField(
description="Template to format the text using `values`"
description="Template to format the text using `values`. Use Jinja2 syntax."
)
class Output(BlockSchema):
@ -160,7 +160,7 @@ class FillTextTemplateBlock(Block):
test_input=[
{
"values": {"name": "Alice", "hello": "Hello", "world": "World!"},
"format": "{hello}, {world} {{name}}",
"format": "{{hello}}, {{ world }} {{name}}",
},
{
"values": {"list": ["Hello", " World!"]},

View File

@ -0,0 +1,60 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
ProviderName,
)
from backend.integrations.oauth.twitter import TwitterOAuthHandler
from backend.util.settings import Secrets
# --8<-- [start:TwitterOAuthIsConfigured]
secrets = Secrets()
TWITTER_OAUTH_IS_CONFIGURED = bool(
secrets.twitter_client_id and secrets.twitter_client_secret
)
# --8<-- [end:TwitterOAuthIsConfigured]
TwitterCredentials = OAuth2Credentials
TwitterCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.TWITTER], Literal["oauth2"]
]
# Currently, We are getting all the permission from the Twitter API initally
# In future, If we need to add incremental permission, we can use these requested_scopes
def TwitterCredentialsField(scopes: list[str]) -> TwitterCredentialsInput:
"""
Creates a Twitter credentials input on a block.
Params:
scopes: The authorization scopes needed for the block to work.
"""
return CredentialsField(
# required_scopes=set(scopes),
required_scopes=set(TwitterOAuthHandler.DEFAULT_SCOPES + scopes),
description="The Twitter integration requires OAuth2 authentication.",
)
TEST_CREDENTIALS = OAuth2Credentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="twitter",
access_token=SecretStr("mock-twitter-access-token"),
refresh_token=SecretStr("mock-twitter-refresh-token"),
access_token_expires_at=1234567890,
scopes=["tweet.read", "tweet.write", "users.read", "offline.access"],
title="Mock Twitter OAuth2 Credentials",
username="mock-twitter-username",
refresh_token_expires_at=1234567890,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

View File

@ -0,0 +1,418 @@
from datetime import datetime
from typing import Any, Dict
from backend.blocks.twitter._mappers import (
get_backend_expansion,
get_backend_field,
get_backend_list_expansion,
get_backend_list_field,
get_backend_media_field,
get_backend_place_field,
get_backend_poll_field,
get_backend_space_expansion,
get_backend_space_field,
get_backend_user_field,
)
from backend.blocks.twitter._types import ( # DMEventFieldFilter,
DMEventExpansionFilter,
DMEventTypeFilter,
DMMediaFieldFilter,
DMTweetFieldFilter,
ExpansionFilter,
ListExpansionsFilter,
ListFieldsFilter,
SpaceExpansionsFilter,
SpaceFieldsFilter,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetReplySettingsFilter,
TweetUserFieldsFilter,
UserExpansionsFilter,
)
# Common Builder
class TweetExpansionsBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_expansions(self, expansions: ExpansionFilter | None):
if expansions:
filtered_expansions = [
name for name, value in expansions.dict().items() if value is True
]
if filtered_expansions:
self.params["expansions"] = ",".join(
[get_backend_expansion(exp) for exp in filtered_expansions]
)
return self
def add_media_fields(self, media_fields: TweetMediaFieldsFilter | None):
if media_fields:
filtered_fields = [
name for name, value in media_fields.dict().items() if value is True
]
if filtered_fields:
self.params["media.fields"] = ",".join(
[get_backend_media_field(field) for field in filtered_fields]
)
return self
def add_place_fields(self, place_fields: TweetPlaceFieldsFilter | None):
if place_fields:
filtered_fields = [
name for name, value in place_fields.dict().items() if value is True
]
if filtered_fields:
self.params["place.fields"] = ",".join(
[get_backend_place_field(field) for field in filtered_fields]
)
return self
def add_poll_fields(self, poll_fields: TweetPollFieldsFilter | None):
if poll_fields:
filtered_fields = [
name for name, value in poll_fields.dict().items() if value is True
]
if filtered_fields:
self.params["poll.fields"] = ",".join(
[get_backend_poll_field(field) for field in filtered_fields]
)
return self
def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
if tweet_fields:
filtered_fields = [
name for name, value in tweet_fields.dict().items() if value is True
]
if filtered_fields:
self.params["tweet.fields"] = ",".join(
[get_backend_field(field) for field in filtered_fields]
)
return self
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
if user_fields:
filtered_fields = [
name for name, value in user_fields.dict().items() if value is True
]
if filtered_fields:
self.params["user.fields"] = ",".join(
[get_backend_user_field(field) for field in filtered_fields]
)
return self
def build(self):
return self.params
class UserExpansionsBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_expansions(self, expansions: UserExpansionsFilter | None):
if expansions:
filtered_expansions = [
name for name, value in expansions.dict().items() if value is True
]
if filtered_expansions:
self.params["expansions"] = ",".join(filtered_expansions)
return self
def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
if tweet_fields:
filtered_fields = [
name for name, value in tweet_fields.dict().items() if value is True
]
if filtered_fields:
self.params["tweet.fields"] = ",".join(
[get_backend_field(field) for field in filtered_fields]
)
return self
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
if user_fields:
filtered_fields = [
name for name, value in user_fields.dict().items() if value is True
]
if filtered_fields:
self.params["user.fields"] = ",".join(
[get_backend_user_field(field) for field in filtered_fields]
)
return self
def build(self):
return self.params
class ListExpansionsBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_expansions(self, expansions: ListExpansionsFilter | None):
if expansions:
filtered_expansions = [
name for name, value in expansions.dict().items() if value is True
]
if filtered_expansions:
self.params["expansions"] = ",".join(
[get_backend_list_expansion(exp) for exp in filtered_expansions]
)
return self
def add_list_fields(self, list_fields: ListFieldsFilter | None):
if list_fields:
filtered_fields = [
name for name, value in list_fields.dict().items() if value is True
]
if filtered_fields:
self.params["list.fields"] = ",".join(
[get_backend_list_field(field) for field in filtered_fields]
)
return self
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
if user_fields:
filtered_fields = [
name for name, value in user_fields.dict().items() if value is True
]
if filtered_fields:
self.params["user.fields"] = ",".join(
[get_backend_user_field(field) for field in filtered_fields]
)
return self
def build(self):
return self.params
class SpaceExpansionsBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_expansions(self, expansions: SpaceExpansionsFilter | None):
if expansions:
filtered_expansions = [
name for name, value in expansions.dict().items() if value is True
]
if filtered_expansions:
self.params["expansions"] = ",".join(
[get_backend_space_expansion(exp) for exp in filtered_expansions]
)
return self
def add_space_fields(self, space_fields: SpaceFieldsFilter | None):
if space_fields:
filtered_fields = [
name for name, value in space_fields.dict().items() if value is True
]
if filtered_fields:
self.params["space.fields"] = ",".join(
[get_backend_space_field(field) for field in filtered_fields]
)
return self
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
if user_fields:
filtered_fields = [
name for name, value in user_fields.dict().items() if value is True
]
if filtered_fields:
self.params["user.fields"] = ",".join(
[get_backend_user_field(field) for field in filtered_fields]
)
return self
def build(self):
return self.params
class TweetDurationBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_start_time(self, start_time: datetime | None):
if start_time:
self.params["start_time"] = start_time
return self
def add_end_time(self, end_time: datetime | None):
if end_time:
self.params["end_time"] = end_time
return self
def add_since_id(self, since_id: str | None):
if since_id:
self.params["since_id"] = since_id
return self
def add_until_id(self, until_id: str | None):
if until_id:
self.params["until_id"] = until_id
return self
def add_sort_order(self, sort_order: str | None):
if sort_order:
self.params["sort_order"] = sort_order
return self
def build(self):
return self.params
class DMExpansionsBuilder:
def __init__(self, param: Dict[str, Any]):
self.params: Dict[str, Any] = param
def add_expansions(self, expansions: DMEventExpansionFilter):
if expansions:
filtered_expansions = [
name for name, value in expansions.dict().items() if value is True
]
if filtered_expansions:
self.params["expansions"] = ",".join(filtered_expansions)
return self
def add_event_types(self, event_types: DMEventTypeFilter):
if event_types:
filtered_types = [
name for name, value in event_types.dict().items() if value is True
]
if filtered_types:
self.params["event_types"] = ",".join(filtered_types)
return self
def add_media_fields(self, media_fields: DMMediaFieldFilter):
if media_fields:
filtered_fields = [
name for name, value in media_fields.dict().items() if value is True
]
if filtered_fields:
self.params["media.fields"] = ",".join(filtered_fields)
return self
def add_tweet_fields(self, tweet_fields: DMTweetFieldFilter):
if tweet_fields:
filtered_fields = [
name for name, value in tweet_fields.dict().items() if value is True
]
if filtered_fields:
self.params["tweet.fields"] = ",".join(filtered_fields)
return self
def add_user_fields(self, user_fields: TweetUserFieldsFilter):
if user_fields:
filtered_fields = [
name for name, value in user_fields.dict().items() if value is True
]
if filtered_fields:
self.params["user.fields"] = ",".join(filtered_fields)
return self
def build(self):
return self.params
# Specific Builders
class TweetSearchBuilder:
def __init__(self):
self.params: Dict[str, Any] = {"user_auth": False}
def add_query(self, query: str):
if query:
self.params["query"] = query
return self
def add_pagination(self, max_results: int, pagination: str | None):
if max_results:
self.params["max_results"] = max_results
if pagination:
self.params["pagination_token"] = pagination
return self
def build(self):
return self.params
class TweetPostBuilder:
def __init__(self):
self.params: Dict[str, Any] = {"user_auth": False}
def add_text(self, text: str | None):
if text:
self.params["text"] = text
return self
def add_media(self, media_ids: list, tagged_user_ids: list):
if media_ids:
self.params["media_ids"] = media_ids
if tagged_user_ids:
self.params["media_tagged_user_ids"] = tagged_user_ids
return self
def add_deep_link(self, link: str):
if link:
self.params["direct_message_deep_link"] = link
return self
def add_super_followers(self, for_super_followers: bool):
if for_super_followers:
self.params["for_super_followers_only"] = for_super_followers
return self
def add_place(self, place_id: str):
if place_id:
self.params["place_id"] = place_id
return self
def add_poll_options(self, poll_options: list):
if poll_options:
self.params["poll_options"] = poll_options
return self
def add_poll_duration(self, poll_duration_minutes: int):
if poll_duration_minutes:
self.params["poll_duration_minutes"] = poll_duration_minutes
return self
def add_quote(self, quote_id: str):
if quote_id:
self.params["quote_tweet_id"] = quote_id
return self
def add_reply_settings(
self,
exclude_user_ids: list,
reply_to_id: str,
settings: TweetReplySettingsFilter,
):
if exclude_user_ids:
self.params["exclude_reply_user_ids"] = exclude_user_ids
if reply_to_id:
self.params["in_reply_to_tweet_id"] = reply_to_id
if settings.All_Users:
self.params["reply_settings"] = None
elif settings.Following_Users_Only:
self.params["reply_settings"] = "following"
elif settings.Mentioned_Users_Only:
self.params["reply_settings"] = "mentionedUsers"
return self
def build(self):
return self.params
class TweetGetsBuilder:
def __init__(self):
self.params: Dict[str, Any] = {"user_auth": False}
def add_id(self, tweet_id: list[str]):
self.params["id"] = tweet_id
return self
def build(self):
return self.params

View File

@ -0,0 +1,234 @@
# -------------- Tweets -----------------
# Tweet Expansions
EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {
"Poll_IDs": "attachments.poll_ids",
"Media_Keys": "attachments.media_keys",
"Author_User_ID": "author_id",
"Edit_History_Tweet_IDs": "edit_history_tweet_ids",
"Mentioned_Usernames": "entities.mentions.username",
"Place_ID": "geo.place_id",
"Reply_To_User_ID": "in_reply_to_user_id",
"Referenced_Tweet_ID": "referenced_tweets.id",
"Referenced_Tweet_Author_ID": "referenced_tweets.id.author_id",
}
def get_backend_expansion(frontend_key: str) -> str:
result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid expansion key: {frontend_key}")
return result
# TweetReplySettings
REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING = {
"Mentioned_Users_Only": "mentionedUsers",
"Following_Users_Only": "following",
"All_Users": "all",
}
# TweetUserFields
def get_backend_reply_setting(frontend_key: str) -> str:
result = REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid reply setting key: {frontend_key}")
return result
USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Account_Creation_Date": "created_at",
"User_Bio": "description",
"User_Entities": "entities",
"User_ID": "id",
"User_Location": "location",
"Latest_Tweet_ID": "most_recent_tweet_id",
"Display_Name": "name",
"Pinned_Tweet_ID": "pinned_tweet_id",
"Profile_Picture_URL": "profile_image_url",
"Is_Protected_Account": "protected",
"Account_Statistics": "public_metrics",
"Profile_URL": "url",
"Username": "username",
"Is_Verified": "verified",
"Verification_Type": "verified_type",
"Content_Withholding_Info": "withheld",
}
def get_backend_user_field(frontend_key: str) -> str:
result = USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid user field key: {frontend_key}")
return result
# TweetFields
FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Tweet_Attachments": "attachments",
"Author_ID": "author_id",
"Context_Annotations": "context_annotations",
"Conversation_ID": "conversation_id",
"Creation_Time": "created_at",
"Edit_Controls": "edit_controls",
"Tweet_Entities": "entities",
"Geographic_Location": "geo",
"Tweet_ID": "id",
"Reply_To_User_ID": "in_reply_to_user_id",
"Language": "lang",
"Public_Metrics": "public_metrics",
"Sensitive_Content_Flag": "possibly_sensitive",
"Referenced_Tweets": "referenced_tweets",
"Reply_Settings": "reply_settings",
"Tweet_Source": "source",
"Tweet_Text": "text",
"Withheld_Content": "withheld",
}
def get_backend_field(frontend_key: str) -> str:
result = FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid field key: {frontend_key}")
return result
# TweetPollFields
POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Duration_Minutes": "duration_minutes",
"End_DateTime": "end_datetime",
"Poll_ID": "id",
"Poll_Options": "options",
"Voting_Status": "voting_status",
}
def get_backend_poll_field(frontend_key: str) -> str:
result = POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid poll field key: {frontend_key}")
return result
PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Contained_Within_Places": "contained_within",
"Country": "country",
"Country_Code": "country_code",
"Full_Location_Name": "full_name",
"Geographic_Coordinates": "geo",
"Place_ID": "id",
"Place_Name": "name",
"Place_Type": "place_type",
}
def get_backend_place_field(frontend_key: str) -> str:
result = PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid place field key: {frontend_key}")
return result
# TweetMediaFields
MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Duration_in_Milliseconds": "duration_ms",
"Height": "height",
"Media_Key": "media_key",
"Preview_Image_URL": "preview_image_url",
"Media_Type": "type",
"Media_URL": "url",
"Width": "width",
"Public_Metrics": "public_metrics",
"Non_Public_Metrics": "non_public_metrics",
"Organic_Metrics": "organic_metrics",
"Promoted_Metrics": "promoted_metrics",
"Alternative_Text": "alt_text",
"Media_Variants": "variants",
}
def get_backend_media_field(frontend_key: str) -> str:
result = MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid media field key: {frontend_key}")
return result
# -------------- Spaces -----------------
# SpaceExpansions
EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE = {
"Invited_Users": "invited_user_ids",
"Speakers": "speaker_ids",
"Creator": "creator_id",
"Hosts": "host_ids",
"Topics": "topic_ids",
}
def get_backend_space_expansion(frontend_key: str) -> str:
result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE.get(frontend_key)
if result is None:
raise KeyError(f"Invalid expansion key: {frontend_key}")
return result
# SpaceFields
SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"Space_ID": "id",
"Space_State": "state",
"Creation_Time": "created_at",
"End_Time": "ended_at",
"Host_User_IDs": "host_ids",
"Language": "lang",
"Is_Ticketed": "is_ticketed",
"Invited_User_IDs": "invited_user_ids",
"Participant_Count": "participant_count",
"Subscriber_Count": "subscriber_count",
"Scheduled_Start_Time": "scheduled_start",
"Speaker_User_IDs": "speaker_ids",
"Start_Time": "started_at",
"Space_Title": "title",
"Topic_IDs": "topic_ids",
"Last_Updated_Time": "updated_at",
}
def get_backend_space_field(frontend_key: str) -> str:
result = SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid space field key: {frontend_key}")
return result
# -------------- List Expansions -----------------
# ListExpansions
LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {"List_Owner_ID": "owner_id"}
def get_backend_list_expansion(frontend_key: str) -> str:
result = LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid list expansion key: {frontend_key}")
return result
LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
"List_ID": "id",
"List_Name": "name",
"Creation_Date": "created_at",
"Description": "description",
"Follower_Count": "follower_count",
"Member_Count": "member_count",
"Is_Private": "private",
"Owner_ID": "owner_id",
}
def get_backend_list_field(frontend_key: str) -> str:
result = LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
if result is None:
raise KeyError(f"Invalid list field key: {frontend_key}")
return result

View File

@ -0,0 +1,76 @@
from typing import Any, Dict, List
class BaseSerializer:
@staticmethod
def _serialize_value(value: Any) -> Any:
"""Helper method to serialize individual values"""
if hasattr(value, "data"):
return value.data
return value
class IncludesSerializer(BaseSerializer):
@classmethod
def serialize(cls, includes: Dict[str, Any]) -> Dict[str, Any]:
"""Serializes the includes dictionary"""
if not includes:
return {}
serialized_includes = {}
for key, value in includes.items():
if isinstance(value, list):
serialized_includes[key] = [
cls._serialize_value(item) for item in value
]
else:
serialized_includes[key] = cls._serialize_value(value)
return serialized_includes
class ResponseDataSerializer(BaseSerializer):
@classmethod
def serialize_dict(cls, item: Dict[str, Any]) -> Dict[str, Any]:
"""Serializes a single dictionary item"""
serialized_item = {}
if hasattr(item, "__dict__"):
items = item.__dict__.items()
else:
items = item.items()
for key, value in items:
if isinstance(value, list):
serialized_item[key] = [
cls._serialize_value(sub_item) for sub_item in value
]
else:
serialized_item[key] = cls._serialize_value(value)
return serialized_item
@classmethod
def serialize_list(cls, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Serializes a list of dictionary items"""
return [cls.serialize_dict(item) for item in data]
class ResponseSerializer:
@classmethod
def serialize(cls, response) -> Dict[str, Any]:
"""Main serializer that handles both data and includes"""
result = {"data": None, "included": {}}
# Handle response.data
if response.data:
if isinstance(response.data, list):
result["data"] = ResponseDataSerializer.serialize_list(response.data)
else:
result["data"] = ResponseDataSerializer.serialize_dict(response.data)
# Handle includes
if hasattr(response, "includes") and response.includes:
result["included"] = IncludesSerializer.serialize(response.includes)
return result

View File

@ -0,0 +1,443 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from backend.data.block import BlockSchema
from backend.data.model import SchemaField
# -------------- Tweets -----------------
class TweetReplySettingsFilter(BaseModel):
Mentioned_Users_Only: bool = False
Following_Users_Only: bool = False
All_Users: bool = False
class TweetUserFieldsFilter(BaseModel):
Account_Creation_Date: bool = False
User_Bio: bool = False
User_Entities: bool = False
User_ID: bool = False
User_Location: bool = False
Latest_Tweet_ID: bool = False
Display_Name: bool = False
Pinned_Tweet_ID: bool = False
Profile_Picture_URL: bool = False
Is_Protected_Account: bool = False
Account_Statistics: bool = False
Profile_URL: bool = False
Username: bool = False
Is_Verified: bool = False
Verification_Type: bool = False
Content_Withholding_Info: bool = False
class TweetFieldsFilter(BaseModel):
Tweet_Attachments: bool = False
Author_ID: bool = False
Context_Annotations: bool = False
Conversation_ID: bool = False
Creation_Time: bool = False
Edit_Controls: bool = False
Tweet_Entities: bool = False
Geographic_Location: bool = False
Tweet_ID: bool = False
Reply_To_User_ID: bool = False
Language: bool = False
Public_Metrics: bool = False
Sensitive_Content_Flag: bool = False
Referenced_Tweets: bool = False
Reply_Settings: bool = False
Tweet_Source: bool = False
Tweet_Text: bool = False
Withheld_Content: bool = False
class PersonalTweetFieldsFilter(BaseModel):
attachments: bool = False
author_id: bool = False
context_annotations: bool = False
conversation_id: bool = False
created_at: bool = False
edit_controls: bool = False
entities: bool = False
geo: bool = False
id: bool = False
in_reply_to_user_id: bool = False
lang: bool = False
non_public_metrics: bool = False
public_metrics: bool = False
organic_metrics: bool = False
promoted_metrics: bool = False
possibly_sensitive: bool = False
referenced_tweets: bool = False
reply_settings: bool = False
source: bool = False
text: bool = False
withheld: bool = False
class TweetPollFieldsFilter(BaseModel):
Duration_Minutes: bool = False
End_DateTime: bool = False
Poll_ID: bool = False
Poll_Options: bool = False
Voting_Status: bool = False
class TweetPlaceFieldsFilter(BaseModel):
Contained_Within_Places: bool = False
Country: bool = False
Country_Code: bool = False
Full_Location_Name: bool = False
Geographic_Coordinates: bool = False
Place_ID: bool = False
Place_Name: bool = False
Place_Type: bool = False
class TweetMediaFieldsFilter(BaseModel):
Duration_in_Milliseconds: bool = False
Height: bool = False
Media_Key: bool = False
Preview_Image_URL: bool = False
Media_Type: bool = False
Media_URL: bool = False
Width: bool = False
Public_Metrics: bool = False
Non_Public_Metrics: bool = False
Organic_Metrics: bool = False
Promoted_Metrics: bool = False
Alternative_Text: bool = False
Media_Variants: bool = False
class ExpansionFilter(BaseModel):
Poll_IDs: bool = False
Media_Keys: bool = False
Author_User_ID: bool = False
Edit_History_Tweet_IDs: bool = False
Mentioned_Usernames: bool = False
Place_ID: bool = False
Reply_To_User_ID: bool = False
Referenced_Tweet_ID: bool = False
Referenced_Tweet_Author_ID: bool = False
class TweetExcludesFilter(BaseModel):
retweets: bool = False
replies: bool = False
# -------------- Users -----------------
class UserExpansionsFilter(BaseModel):
pinned_tweet_id: bool = False
# -------------- DM's' -----------------
class DMEventFieldFilter(BaseModel):
id: bool = False
text: bool = False
event_type: bool = False
created_at: bool = False
dm_conversation_id: bool = False
sender_id: bool = False
participant_ids: bool = False
referenced_tweets: bool = False
attachments: bool = False
class DMEventTypeFilter(BaseModel):
MessageCreate: bool = False
ParticipantsJoin: bool = False
ParticipantsLeave: bool = False
class DMEventExpansionFilter(BaseModel):
attachments_media_keys: bool = False
referenced_tweets_id: bool = False
sender_id: bool = False
participant_ids: bool = False
class DMMediaFieldFilter(BaseModel):
duration_ms: bool = False
height: bool = False
media_key: bool = False
preview_image_url: bool = False
type: bool = False
url: bool = False
width: bool = False
public_metrics: bool = False
alt_text: bool = False
variants: bool = False
class DMTweetFieldFilter(BaseModel):
attachments: bool = False
author_id: bool = False
context_annotations: bool = False
conversation_id: bool = False
created_at: bool = False
edit_controls: bool = False
entities: bool = False
geo: bool = False
id: bool = False
in_reply_to_user_id: bool = False
lang: bool = False
public_metrics: bool = False
possibly_sensitive: bool = False
referenced_tweets: bool = False
reply_settings: bool = False
source: bool = False
text: bool = False
withheld: bool = False
# -------------- Spaces -----------------
class SpaceExpansionsFilter(BaseModel):
Invited_Users: bool = False
Speakers: bool = False
Creator: bool = False
Hosts: bool = False
Topics: bool = False
class SpaceFieldsFilter(BaseModel):
Space_ID: bool = False
Space_State: bool = False
Creation_Time: bool = False
End_Time: bool = False
Host_User_IDs: bool = False
Language: bool = False
Is_Ticketed: bool = False
Invited_User_IDs: bool = False
Participant_Count: bool = False
Subscriber_Count: bool = False
Scheduled_Start_Time: bool = False
Speaker_User_IDs: bool = False
Start_Time: bool = False
Space_Title: bool = False
Topic_IDs: bool = False
Last_Updated_Time: bool = False
class SpaceStatesFilter(str, Enum):
live = "live"
scheduled = "scheduled"
all = "all"
# -------------- List Expansions -----------------
class ListExpansionsFilter(BaseModel):
List_Owner_ID: bool = False
class ListFieldsFilter(BaseModel):
List_ID: bool = False
List_Name: bool = False
Creation_Date: bool = False
Description: bool = False
Follower_Count: bool = False
Member_Count: bool = False
Is_Private: bool = False
Owner_ID: bool = False
# --------- [Input Types] -------------
class TweetExpansionInputs(BlockSchema):
expansions: ExpansionFilter | None = SchemaField(
description="Choose what extra information you want to get with your tweets. For example:\n- Select 'Media_Keys' to get media details\n- Select 'Author_User_ID' to get user information\n- Select 'Place_ID' to get location details",
placeholder="Pick the extra information you want to see",
default=None,
advanced=True,
)
media_fields: TweetMediaFieldsFilter | None = SchemaField(
description="Select what media information you want to see (images, videos, etc). To use this, you must first select 'Media_Keys' in the expansions above.",
placeholder="Choose what media details you want to see",
default=None,
advanced=True,
)
place_fields: TweetPlaceFieldsFilter | None = SchemaField(
description="Select what location information you want to see (country, coordinates, etc). To use this, you must first select 'Place_ID' in the expansions above.",
placeholder="Choose what location details you want to see",
default=None,
advanced=True,
)
poll_fields: TweetPollFieldsFilter | None = SchemaField(
description="Select what poll information you want to see (options, voting status, etc). To use this, you must first select 'Poll_IDs' in the expansions above.",
placeholder="Choose what poll details you want to see",
default=None,
advanced=True,
)
tweet_fields: TweetFieldsFilter | None = SchemaField(
description="Select what tweet information you want to see. For referenced tweets (like retweets), select 'Referenced_Tweet_ID' in the expansions above.",
placeholder="Choose what tweet details you want to see",
default=None,
advanced=True,
)
user_fields: TweetUserFieldsFilter | None = SchemaField(
description="Select what user information you want to see. To use this, you must first select one of these in expansions above:\n- 'Author_User_ID' for tweet authors\n- 'Mentioned_Usernames' for mentioned users\n- 'Reply_To_User_ID' for users being replied to\n- 'Referenced_Tweet_Author_ID' for authors of referenced tweets",
placeholder="Choose what user details you want to see",
default=None,
advanced=True,
)
class DMEventExpansionInputs(BlockSchema):
expansions: DMEventExpansionFilter | None = SchemaField(
description="Select expansions to include related data objects in the 'includes' section.",
placeholder="Enter expansions",
default=None,
advanced=True,
)
event_types: DMEventTypeFilter | None = SchemaField(
description="Select DM event types to include in the response.",
placeholder="Enter event types",
default=None,
advanced=True,
)
media_fields: DMMediaFieldFilter | None = SchemaField(
description="Select media fields to include in the response (requires expansions=attachments.media_keys).",
placeholder="Enter media fields",
default=None,
advanced=True,
)
tweet_fields: DMTweetFieldFilter | None = SchemaField(
description="Select tweet fields to include in the response (requires expansions=referenced_tweets.id).",
placeholder="Enter tweet fields",
default=None,
advanced=True,
)
user_fields: TweetUserFieldsFilter | None = SchemaField(
description="Select user fields to include in the response (requires expansions=sender_id or participant_ids).",
placeholder="Enter user fields",
default=None,
advanced=True,
)
class UserExpansionInputs(BlockSchema):
expansions: UserExpansionsFilter | None = SchemaField(
description="Choose what extra information you want to get with user data. Currently only 'pinned_tweet_id' is available to see a user's pinned tweet.",
placeholder="Select extra user information to include",
default=None,
advanced=True,
)
tweet_fields: TweetFieldsFilter | None = SchemaField(
description="Select what tweet information you want to see in pinned tweets. This only works if you select 'pinned_tweet_id' in expansions above.",
placeholder="Choose what details to see in pinned tweets",
default=None,
advanced=True,
)
user_fields: TweetUserFieldsFilter | None = SchemaField(
description="Select what user information you want to see, like username, bio, profile picture, etc.",
placeholder="Choose what user details you want to see",
default=None,
advanced=True,
)
class SpaceExpansionInputs(BlockSchema):
expansions: SpaceExpansionsFilter | None = SchemaField(
description="Choose additional information you want to get with your Twitter Spaces:\n- Select 'Invited_Users' to see who was invited\n- Select 'Speakers' to see who can speak\n- Select 'Creator' to get details about who made the Space\n- Select 'Hosts' to see who's hosting\n- Select 'Topics' to see Space topics",
placeholder="Pick what extra information you want to see about the Space",
default=None,
advanced=True,
)
space_fields: SpaceFieldsFilter | None = SchemaField(
description="Choose what Space details you want to see, such as:\n- Title\n- Start/End times\n- Number of participants\n- Language\n- State (live/scheduled)\n- And more",
placeholder="Choose what Space information you want to get",
default=SpaceFieldsFilter(Space_Title=True, Host_User_IDs=True),
advanced=True,
)
user_fields: TweetUserFieldsFilter | None = SchemaField(
description="Choose what user information you want to see. This works when you select any of these in expansions above:\n- 'Creator' for Space creator details\n- 'Hosts' for host information\n- 'Speakers' for speaker details\n- 'Invited_Users' for invited user information",
placeholder="Pick what details you want to see about the users",
default=None,
advanced=True,
)
class ListExpansionInputs(BlockSchema):
expansions: ListExpansionsFilter | None = SchemaField(
description="Choose what extra information you want to get with your Twitter Lists:\n- Select 'List_Owner_ID' to get details about who owns the list\n\nThis will let you see more details about the list owner when you also select user fields below.",
placeholder="Pick what extra list information you want to see",
default=ListExpansionsFilter(List_Owner_ID=True),
advanced=True,
)
user_fields: TweetUserFieldsFilter | None = SchemaField(
description="Choose what information you want to see about list owners. This only works when you select 'List_Owner_ID' in expansions above.\n\nYou can see things like:\n- Their username\n- Profile picture\n- Account details\n- And more",
placeholder="Select what details you want to see about list owners",
default=TweetUserFieldsFilter(User_ID=True, Username=True),
advanced=True,
)
list_fields: ListFieldsFilter | None = SchemaField(
description="Choose what information you want to see about the Twitter Lists themselves, such as:\n- List name\n- Description\n- Number of followers\n- Number of members\n- Whether it's private\n- Creation date\n- And more",
placeholder="Pick what list details you want to see",
default=ListFieldsFilter(Owner_ID=True),
advanced=True,
)
class TweetTimeWindowInputs(BlockSchema):
start_time: datetime | None = SchemaField(
description="Start time in YYYY-MM-DDTHH:mm:ssZ format",
placeholder="Enter start time",
default=None,
advanced=False,
)
end_time: datetime | None = SchemaField(
description="End time in YYYY-MM-DDTHH:mm:ssZ format",
placeholder="Enter end time",
default=None,
advanced=False,
)
since_id: str | None = SchemaField(
description="Returns results with Tweet ID greater than this (more recent than), we give priority to since_id over start_time",
placeholder="Enter since ID",
default=None,
advanced=True,
)
until_id: str | None = SchemaField(
description="Returns results with Tweet ID less than this (that is, older than), and used with since_id",
placeholder="Enter until ID",
default=None,
advanced=True,
)
sort_order: str | None = SchemaField(
description="Order of returned tweets (recency or relevancy)",
placeholder="Enter sort order",
default=None,
advanced=True,
)

View File

@ -0,0 +1,201 @@
# Todo : Add new Type support
# from typing import cast
# import tweepy
# from tweepy.client import Response
# from backend.blocks.twitter._serializer import IncludesSerializer, ResponseDataSerializer
# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
# from backend.data.model import SchemaField
# from backend.blocks.twitter._builders import DMExpansionsBuilder
# from backend.blocks.twitter._types import DMEventExpansion, DMEventExpansionInputs, DMEventType, DMMediaField, DMTweetField, TweetUserFields
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
# from backend.blocks.twitter._auth import (
# TEST_CREDENTIALS,
# TEST_CREDENTIALS_INPUT,
# TwitterCredentials,
# TwitterCredentialsField,
# TwitterCredentialsInput,
# )
# Require Pro or Enterprise plan [Manual Testing Required]
# class TwitterGetDMEventsBlock(Block):
# """
# Gets a list of Direct Message events for the authenticated user
# """
# class Input(DMEventExpansionInputs):
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
# ["dm.read", "offline.access", "user.read", "tweet.read"]
# )
# dm_conversation_id: str = SchemaField(
# description="The ID of the Direct Message conversation",
# placeholder="Enter conversation ID",
# required=True
# )
# max_results: int = SchemaField(
# description="Maximum number of results to return (1-100)",
# placeholder="Enter max results",
# advanced=True,
# default=10,
# )
# pagination_token: str = SchemaField(
# description="Token for pagination",
# placeholder="Enter pagination token",
# advanced=True,
# default=""
# )
# class Output(BlockSchema):
# # Common outputs
# event_ids: list[str] = SchemaField(description="DM Event IDs")
# event_texts: list[str] = SchemaField(description="DM Event text contents")
# event_types: list[str] = SchemaField(description="Types of DM events")
# next_token: str = SchemaField(description="Token for next page of results")
# # Complete outputs
# data: list[dict] = SchemaField(description="Complete DM events data")
# included: dict = SchemaField(description="Additional data requested via expansions")
# meta: dict = SchemaField(description="Metadata about the response")
# error: str = SchemaField(description="Error message if request failed")
# def __init__(self):
# super().__init__(
# id="dc37a6d4-a62e-11ef-a3a5-03061375737b",
# description="This block retrieves Direct Message events for the authenticated user.",
# categories={BlockCategory.SOCIAL},
# input_schema=TwitterGetDMEventsBlock.Input,
# output_schema=TwitterGetDMEventsBlock.Output,
# test_input={
# "dm_conversation_id": "1234567890",
# "max_results": 10,
# "credentials": TEST_CREDENTIALS_INPUT,
# "expansions": [],
# "event_types": [],
# "media_fields": [],
# "tweet_fields": [],
# "user_fields": []
# },
# test_credentials=TEST_CREDENTIALS,
# test_output=[
# ("event_ids", ["1346889436626259968"]),
# ("event_texts", ["Hello just you..."]),
# ("event_types", ["MessageCreate"]),
# ("next_token", None),
# ("data", [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}]),
# ("included", {}),
# ("meta", {}),
# ("error", "")
# ],
# test_mock={
# "get_dm_events": lambda *args, **kwargs: (
# [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}],
# {},
# {},
# ["1346889436626259968"],
# ["Hello just you..."],
# ["MessageCreate"],
# None
# )
# }
# )
# @staticmethod
# def get_dm_events(
# credentials: TwitterCredentials,
# dm_conversation_id: str,
# max_results: int,
# pagination_token: str,
# expansions: list[DMEventExpansion],
# event_types: list[DMEventType],
# media_fields: list[DMMediaField],
# tweet_fields: list[DMTweetField],
# user_fields: list[TweetUserFields]
# ):
# try:
# client = tweepy.Client(
# bearer_token=credentials.access_token.get_secret_value()
# )
# params = {
# "dm_conversation_id": dm_conversation_id,
# "max_results": max_results,
# "pagination_token": None if pagination_token == "" else pagination_token,
# "user_auth": False
# }
# params = (DMExpansionsBuilder(params)
# .add_expansions(expansions)
# .add_event_types(event_types)
# .add_media_fields(media_fields)
# .add_tweet_fields(tweet_fields)
# .add_user_fields(user_fields)
# .build())
# response = cast(Response, client.get_direct_message_events(**params))
# meta = {}
# event_ids = []
# event_texts = []
# event_types = []
# next_token = None
# if response.meta:
# meta = response.meta
# next_token = meta.get("next_token")
# included = IncludesSerializer.serialize(response.includes)
# data = ResponseDataSerializer.serialize_list(response.data)
# if response.data:
# event_ids = [str(item.id) for item in response.data]
# event_texts = [item.text if hasattr(item, "text") else None for item in response.data]
# event_types = [item.event_type for item in response.data]
# return data, included, meta, event_ids, event_texts, event_types, next_token
# raise Exception("No DM events found")
# except tweepy.TweepyException:
# raise
# def run(
# self,
# input_data: Input,
# *,
# credentials: TwitterCredentials,
# **kwargs,
# ) -> BlockOutput:
# try:
# event_data, included, meta, event_ids, event_texts, event_types, next_token = self.get_dm_events(
# credentials,
# input_data.dm_conversation_id,
# input_data.max_results,
# input_data.pagination_token,
# input_data.expansions,
# input_data.event_types,
# input_data.media_fields,
# input_data.tweet_fields,
# input_data.user_fields
# )
# if event_ids:
# yield "event_ids", event_ids
# if event_texts:
# yield "event_texts", event_texts
# if event_types:
# yield "event_types", event_types
# if next_token:
# yield "next_token", next_token
# if event_data:
# yield "data", event_data
# if included:
# yield "included", included
# if meta:
# yield "meta", meta
# except Exception as e:
# yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,260 @@
# Todo : Add new Type support
# from typing import cast
# import tweepy
# from tweepy.client import Response
# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
# from backend.data.model import SchemaField
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
# from backend.blocks.twitter._auth import (
# TEST_CREDENTIALS,
# TEST_CREDENTIALS_INPUT,
# TwitterCredentials,
# TwitterCredentialsField,
# TwitterCredentialsInput,
# )
# Pro and Enterprise plan [Manual Testing Required]
# class TwitterSendDirectMessageBlock(Block):
# """
# Sends a direct message to a Twitter user
# """
# class Input(BlockSchema):
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
# ["offline.access", "direct_messages.write"]
# )
# participant_id: str = SchemaField(
# description="The User ID of the account to send DM to",
# placeholder="Enter recipient user ID",
# default="",
# advanced=False
# )
# dm_conversation_id: str = SchemaField(
# description="The conversation ID to send message to",
# placeholder="Enter conversation ID",
# default="",
# advanced=False
# )
# text: str = SchemaField(
# description="Text of the Direct Message (up to 10,000 characters)",
# placeholder="Enter message text",
# default="",
# advanced=False
# )
# media_id: str = SchemaField(
# description="Media ID to attach to the message",
# placeholder="Enter media ID",
# default=""
# )
# class Output(BlockSchema):
# dm_event_id: str = SchemaField(description="ID of the sent direct message")
# dm_conversation_id_: str = SchemaField(description="ID of the conversation")
# error: str = SchemaField(description="Error message if sending failed")
# def __init__(self):
# super().__init__(
# id="f32f2786-a62e-11ef-a93d-a3ef199dde7f",
# description="This block sends a direct message to a specified Twitter user.",
# categories={BlockCategory.SOCIAL},
# input_schema=TwitterSendDirectMessageBlock.Input,
# output_schema=TwitterSendDirectMessageBlock.Output,
# test_input={
# "participant_id": "783214",
# "dm_conversation_id": "",
# "text": "Hello from Twitter API",
# "media_id": "",
# "credentials": TEST_CREDENTIALS_INPUT
# },
# test_credentials=TEST_CREDENTIALS,
# test_output=[
# ("dm_event_id", "0987654321"),
# ("dm_conversation_id_", "1234567890"),
# ("error", "")
# ],
# test_mock={
# "send_direct_message": lambda *args, **kwargs: (
# "0987654321",
# "1234567890"
# )
# },
# )
# @staticmethod
# def send_direct_message(
# credentials: TwitterCredentials,
# participant_id: str,
# dm_conversation_id: str,
# text: str,
# media_id: str
# ):
# try:
# client = tweepy.Client(
# bearer_token=credentials.access_token.get_secret_value()
# )
# response = cast(
# Response,
# client.create_direct_message(
# participant_id=None if participant_id == "" else participant_id,
# dm_conversation_id=None if dm_conversation_id == "" else dm_conversation_id,
# text=None if text == "" else text,
# media_id=None if media_id == "" else media_id,
# user_auth=False
# )
# )
# if not response.data:
# raise Exception("Failed to send direct message")
# return response.data["dm_event_id"], response.data["dm_conversation_id"]
# except tweepy.TweepyException:
# raise
# except Exception as e:
# print(f"Unexpected error: {str(e)}")
# raise
# def run(
# self,
# input_data: Input,
# *,
# credentials: TwitterCredentials,
# **kwargs,
# ) -> BlockOutput:
# try:
# dm_event_id, dm_conversation_id = self.send_direct_message(
# credentials,
# input_data.participant_id,
# input_data.dm_conversation_id,
# input_data.text,
# input_data.media_id
# )
# yield "dm_event_id", dm_event_id
# yield "dm_conversation_id", dm_conversation_id
# except Exception as e:
# yield "error", handle_tweepy_exception(e)
# class TwitterCreateDMConversationBlock(Block):
# """
# Creates a new group direct message conversation on Twitter
# """
# class Input(BlockSchema):
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
# ["offline.access", "dm.write","dm.read","tweet.read","user.read"]
# )
# participant_ids: list[str] = SchemaField(
# description="Array of User IDs to create conversation with (max 50)",
# placeholder="Enter participant user IDs",
# default=[],
# advanced=False
# )
# text: str = SchemaField(
# description="Text of the Direct Message (up to 10,000 characters)",
# placeholder="Enter message text",
# default="",
# advanced=False
# )
# media_id: str = SchemaField(
# description="Media ID to attach to the message",
# placeholder="Enter media ID",
# default="",
# advanced=False
# )
# class Output(BlockSchema):
# dm_event_id: str = SchemaField(description="ID of the sent direct message")
# dm_conversation_id: str = SchemaField(description="ID of the conversation")
# error: str = SchemaField(description="Error message if sending failed")
# def __init__(self):
# super().__init__(
# id="ec11cabc-a62e-11ef-8c0e-3fe37ba2ec92",
# description="This block creates a new group DM conversation with specified Twitter users.",
# categories={BlockCategory.SOCIAL},
# input_schema=TwitterCreateDMConversationBlock.Input,
# output_schema=TwitterCreateDMConversationBlock.Output,
# test_input={
# "participant_ids": ["783214", "2244994945"],
# "text": "Hello from Twitter API",
# "media_id": "",
# "credentials": TEST_CREDENTIALS_INPUT
# },
# test_credentials=TEST_CREDENTIALS,
# test_output=[
# ("dm_event_id", "0987654321"),
# ("dm_conversation_id", "1234567890"),
# ("error", "")
# ],
# test_mock={
# "create_dm_conversation": lambda *args, **kwargs: (
# "0987654321",
# "1234567890"
# )
# },
# )
# @staticmethod
# def create_dm_conversation(
# credentials: TwitterCredentials,
# participant_ids: list[str],
# text: str,
# media_id: str
# ):
# try:
# client = tweepy.Client(
# bearer_token=credentials.access_token.get_secret_value()
# )
# response = cast(
# Response,
# client.create_direct_message_conversation(
# participant_ids=participant_ids,
# text=None if text == "" else text,
# media_id=None if media_id == "" else media_id,
# user_auth=False
# )
# )
# if not response.data:
# raise Exception("Failed to create DM conversation")
# return response.data["dm_event_id"], response.data["dm_conversation_id"]
# except tweepy.TweepyException:
# raise
# except Exception as e:
# print(f"Unexpected error: {str(e)}")
# raise
# def run(
# self,
# input_data: Input,
# *,
# credentials: TwitterCredentials,
# **kwargs,
# ) -> BlockOutput:
# try:
# dm_event_id, dm_conversation_id = self.create_dm_conversation(
# credentials,
# input_data.participant_ids,
# input_data.text,
# input_data.media_id
# )
# yield "dm_event_id", dm_event_id
# yield "dm_conversation_id", dm_conversation_id
# except Exception as e:
# yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,470 @@
# from typing import cast
import tweepy
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
# from backend.blocks.twitter._builders import UserExpansionsBuilder
# from backend.blocks.twitter._types import TweetFields, TweetUserFields, UserExpansionInputs, UserExpansions
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
# from tweepy.client import Response
class TwitterUnfollowListBlock(Block):
"""
Unfollows a Twitter list for the authenticated user
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["follows.write", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to unfollow",
placeholder="Enter list ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the unfollow was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="1f43310a-a62f-11ef-8276-2b06a1bbae1a",
description="This block unfollows a specified Twitter list for the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnfollowListBlock.Input,
output_schema=TwitterUnfollowListBlock.Output,
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unfollow_list": lambda *args, **kwargs: True},
)
@staticmethod
def unfollow_list(credentials: TwitterCredentials, list_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unfollow_list(list_id=list_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unfollow_list(credentials, input_data.list_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterFollowListBlock(Block):
"""
Follows a Twitter list for the authenticated user
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "list.write", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to follow",
placeholder="Enter list ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the follow was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="03d8acf6-a62f-11ef-b17f-b72b04a09e79",
description="This block follows a specified Twitter list for the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterFollowListBlock.Input,
output_schema=TwitterFollowListBlock.Output,
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"follow_list": lambda *args, **kwargs: True},
)
@staticmethod
def follow_list(credentials: TwitterCredentials, list_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.follow_list(list_id=list_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.follow_list(credentials, input_data.list_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
# Enterprise Level [Need to do Manual testing], There is a high possibility that we might get error in this
# Needs Type Input in this
# class TwitterListGetFollowersBlock(Block):
# """
# Gets followers of a specified Twitter list
# """
# class Input(UserExpansionInputs):
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
# ["tweet.read","users.read", "list.read", "offline.access"]
# )
# list_id: str = SchemaField(
# description="The ID of the List to get followers for",
# placeholder="Enter list ID",
# required=True
# )
# max_results: int = SchemaField(
# description="Max number of results per page (1-100)",
# placeholder="Enter max results",
# default=10,
# advanced=True,
# )
# pagination_token: str = SchemaField(
# description="Token for pagination",
# placeholder="Enter pagination token",
# default="",
# advanced=True,
# )
# class Output(BlockSchema):
# user_ids: list[str] = SchemaField(description="List of user IDs of followers")
# usernames: list[str] = SchemaField(description="List of usernames of followers")
# next_token: str = SchemaField(description="Token for next page of results")
# data: list[dict] = SchemaField(description="Complete follower data")
# included: dict = SchemaField(description="Additional data requested via expansions")
# meta: dict = SchemaField(description="Metadata about the response")
# error: str = SchemaField(description="Error message if the request failed")
# def __init__(self):
# super().__init__(
# id="16b289b4-a62f-11ef-95d4-bb29b849eb99",
# description="This block retrieves followers of a specified Twitter list.",
# categories={BlockCategory.SOCIAL},
# input_schema=TwitterListGetFollowersBlock.Input,
# output_schema=TwitterListGetFollowersBlock.Output,
# test_input={
# "list_id": "123456789",
# "max_results": 10,
# "pagination_token": None,
# "credentials": TEST_CREDENTIALS_INPUT,
# "expansions": [],
# "tweet_fields": [],
# "user_fields": []
# },
# test_credentials=TEST_CREDENTIALS,
# test_output=[
# ("user_ids", ["2244994945"]),
# ("usernames", ["testuser"]),
# ("next_token", None),
# ("data", {"followers": [{"id": "2244994945", "username": "testuser"}]}),
# ("included", {}),
# ("meta", {}),
# ("error", "")
# ],
# test_mock={
# "get_list_followers": lambda *args, **kwargs: ({
# "followers": [{"id": "2244994945", "username": "testuser"}]
# }, {}, {}, ["2244994945"], ["testuser"], None)
# }
# )
# @staticmethod
# def get_list_followers(
# credentials: TwitterCredentials,
# list_id: str,
# max_results: int,
# pagination_token: str,
# expansions: list[UserExpansions],
# tweet_fields: list[TweetFields],
# user_fields: list[TweetUserFields]
# ):
# try:
# client = tweepy.Client(
# bearer_token=credentials.access_token.get_secret_value(),
# )
# params = {
# "id": list_id,
# "max_results": max_results,
# "pagination_token": None if pagination_token == "" else pagination_token,
# "user_auth": False
# }
# params = (UserExpansionsBuilder(params)
# .add_expansions(expansions)
# .add_tweet_fields(tweet_fields)
# .add_user_fields(user_fields)
# .build())
# response = cast(
# Response,
# client.get_list_followers(**params)
# )
# meta = {}
# user_ids = []
# usernames = []
# next_token = None
# if response.meta:
# meta = response.meta
# next_token = meta.get("next_token")
# included = IncludesSerializer.serialize(response.includes)
# data = ResponseDataSerializer.serialize_list(response.data)
# if response.data:
# user_ids = [str(item.id) for item in response.data]
# usernames = [item.username for item in response.data]
# return data, included, meta, user_ids, usernames, next_token
# raise Exception("No followers found")
# except tweepy.TweepyException:
# raise
# def run(
# self,
# input_data: Input,
# *,
# credentials: TwitterCredentials,
# **kwargs,
# ) -> BlockOutput:
# try:
# followers_data, included, meta, user_ids, usernames, next_token = self.get_list_followers(
# credentials,
# input_data.list_id,
# input_data.max_results,
# input_data.pagination_token,
# input_data.expansions,
# input_data.tweet_fields,
# input_data.user_fields
# )
# if user_ids:
# yield "user_ids", user_ids
# if usernames:
# yield "usernames", usernames
# if next_token:
# yield "next_token", next_token
# if followers_data:
# yield "data", followers_data
# if included:
# yield "included", included
# if meta:
# yield "meta", meta
# except Exception as e:
# yield "error", handle_tweepy_exception(e)
# class TwitterGetFollowedListsBlock(Block):
# """
# Gets lists followed by a specified Twitter user
# """
# class Input(UserExpansionInputs):
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
# ["follows.read", "users.read", "list.read", "offline.access"]
# )
# user_id: str = SchemaField(
# description="The user ID whose followed Lists to retrieve",
# placeholder="Enter user ID",
# required=True
# )
# max_results: int = SchemaField(
# description="Max number of results per page (1-100)",
# placeholder="Enter max results",
# default=10,
# advanced=True,
# )
# pagination_token: str = SchemaField(
# description="Token for pagination",
# placeholder="Enter pagination token",
# default="",
# advanced=True,
# )
# class Output(BlockSchema):
# list_ids: list[str] = SchemaField(description="List of list IDs")
# list_names: list[str] = SchemaField(description="List of list names")
# data: list[dict] = SchemaField(description="Complete list data")
# includes: dict = SchemaField(description="Additional data requested via expansions")
# meta: dict = SchemaField(description="Metadata about the response")
# next_token: str = SchemaField(description="Token for next page of results")
# error: str = SchemaField(description="Error message if the request failed")
# def __init__(self):
# super().__init__(
# id="0e18bbfc-a62f-11ef-94fa-1f1e174b809e",
# description="This block retrieves all Lists a specified user follows.",
# categories={BlockCategory.SOCIAL},
# input_schema=TwitterGetFollowedListsBlock.Input,
# output_schema=TwitterGetFollowedListsBlock.Output,
# test_input={
# "user_id": "123456789",
# "max_results": 10,
# "pagination_token": None,
# "credentials": TEST_CREDENTIALS_INPUT,
# "expansions": [],
# "tweet_fields": [],
# "user_fields": []
# },
# test_credentials=TEST_CREDENTIALS,
# test_output=[
# ("list_ids", ["12345"]),
# ("list_names", ["Test List"]),
# ("data", {"followed_lists": [{"id": "12345", "name": "Test List"}]}),
# ("includes", {}),
# ("meta", {}),
# ("next_token", None),
# ("error", "")
# ],
# test_mock={
# "get_followed_lists": lambda *args, **kwargs: ({
# "followed_lists": [{"id": "12345", "name": "Test List"}]
# }, {}, {}, ["12345"], ["Test List"], None)
# }
# )
# @staticmethod
# def get_followed_lists(
# credentials: TwitterCredentials,
# user_id: str,
# max_results: int,
# pagination_token: str,
# expansions: list[UserExpansions],
# tweet_fields: list[TweetFields],
# user_fields: list[TweetUserFields]
# ):
# try:
# client = tweepy.Client(
# bearer_token=credentials.access_token.get_secret_value(),
# )
# params = {
# "id": user_id,
# "max_results": max_results,
# "pagination_token": None if pagination_token == "" else pagination_token,
# "user_auth": False
# }
# params = (UserExpansionsBuilder(params)
# .add_expansions(expansions)
# .add_tweet_fields(tweet_fields)
# .add_user_fields(user_fields)
# .build())
# response = cast(
# Response,
# client.get_followed_lists(**params)
# )
# meta = {}
# list_ids = []
# list_names = []
# next_token = None
# if response.meta:
# meta = response.meta
# next_token = meta.get("next_token")
# included = IncludesSerializer.serialize(response.includes)
# data = ResponseDataSerializer.serialize_list(response.data)
# if response.data:
# list_ids = [str(item.id) for item in response.data]
# list_names = [item.name for item in response.data]
# return data, included, meta, list_ids, list_names, next_token
# raise Exception("No followed lists found")
# except tweepy.TweepyException:
# raise
# def run(
# self,
# input_data: Input,
# *,
# credentials: TwitterCredentials,
# **kwargs,
# ) -> BlockOutput:
# try:
# lists_data, included, meta, list_ids, list_names, next_token = self.get_followed_lists(
# credentials,
# input_data.user_id,
# input_data.max_results,
# input_data.pagination_token,
# input_data.expansions,
# input_data.tweet_fields,
# input_data.user_fields
# )
# if list_ids:
# yield "list_ids", list_ids
# if list_names:
# yield "list_names", list_names
# if next_token:
# yield "next_token", next_token
# if lists_data:
# yield "data", lists_data
# if included:
# yield "includes", included
# if meta:
# yield "meta", meta
# except Exception as e:
# yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,348 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import ListExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ListExpansionInputs,
ListExpansionsFilter,
ListFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterGetListBlock(Block):
"""
Gets information about a Twitter List specified by ID
"""
class Input(ListExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to lookup",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
# Common outputs
id: str = SchemaField(description="ID of the Twitter List")
name: str = SchemaField(description="Name of the Twitter List")
owner_id: str = SchemaField(description="ID of the List owner")
owner_username: str = SchemaField(description="Username of the List owner")
# Complete outputs
data: dict = SchemaField(description="Complete list data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata about the response")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="34ebc80a-a62f-11ef-9c2a-3fcab6c07079",
description="This block retrieves information about a specified Twitter List.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetListBlock.Input,
output_schema=TwitterGetListBlock.Output,
test_input={
"list_id": "84839422",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"list_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", "84839422"),
("name", "Official Twitter Accounts"),
("owner_id", "2244994945"),
("owner_username", "TwitterAPI"),
("data", {"id": "84839422", "name": "Official Twitter Accounts"}),
],
test_mock={
"get_list": lambda *args, **kwargs: (
{"id": "84839422", "name": "Official Twitter Accounts"},
{},
{},
"2244994945",
"TwitterAPI",
)
},
)
@staticmethod
def get_list(
credentials: TwitterCredentials,
list_id: str,
expansions: ListExpansionsFilter | None,
user_fields: TweetUserFieldsFilter | None,
list_fields: ListFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {"id": list_id, "user_auth": False}
params = (
ListExpansionsBuilder(params)
.add_expansions(expansions)
.add_user_fields(user_fields)
.add_list_fields(list_fields)
.build()
)
response = cast(Response, client.get_list(**params))
meta = {}
owner_id = ""
owner_username = ""
included = {}
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if "users" in included:
owner_id = str(included["users"][0]["id"])
owner_username = included["users"][0]["username"]
if response.meta:
meta = response.meta
if response.data:
data_dict = ResponseDataSerializer.serialize_dict(response.data)
return data_dict, included, meta, owner_id, owner_username
raise Exception("List not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
list_data, included, meta, owner_id, owner_username = self.get_list(
credentials,
input_data.list_id,
input_data.expansions,
input_data.user_fields,
input_data.list_fields,
)
yield "id", str(list_data["id"])
yield "name", list_data["name"]
if owner_id:
yield "owner_id", owner_id
if owner_username:
yield "owner_username", owner_username
yield "data", {"id": list_data["id"], "name": list_data["name"]}
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetOwnedListsBlock(Block):
"""
Gets all Lists owned by the specified user
"""
class Input(ListExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "list.read", "offline.access"]
)
user_id: str = SchemaField(
description="The user ID whose owned Lists to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(
description="Maximum number of results per page (1-100)",
placeholder="Enter max results (default 100)",
advanced=True,
default=10,
)
pagination_token: str | None = SchemaField(
description="Token for pagination",
placeholder="Enter pagination token",
advanced=True,
default="",
)
class Output(BlockSchema):
# Common outputs
list_ids: list[str] = SchemaField(description="List ids of the owned lists")
list_names: list[str] = SchemaField(description="List names of the owned lists")
next_token: str = SchemaField(description="Token for next page of results")
# Complete outputs
data: list[dict] = SchemaField(description="Complete owned lists data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata about the response")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="2b6bdb26-a62f-11ef-a9ce-ff89c2568726",
description="This block retrieves all Lists owned by a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetOwnedListsBlock.Input,
output_schema=TwitterGetOwnedListsBlock.Output,
test_input={
"user_id": "2244994945",
"max_results": 10,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"list_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("list_ids", ["84839422"]),
("list_names", ["Official Twitter Accounts"]),
("data", [{"id": "84839422", "name": "Official Twitter Accounts"}]),
],
test_mock={
"get_owned_lists": lambda *args, **kwargs: (
[{"id": "84839422", "name": "Official Twitter Accounts"}],
{},
{},
["84839422"],
["Official Twitter Accounts"],
None,
)
},
)
@staticmethod
def get_owned_lists(
credentials: TwitterCredentials,
user_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: ListExpansionsFilter | None,
user_fields: TweetUserFieldsFilter | None,
list_fields: ListFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": user_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
ListExpansionsBuilder(params)
.add_expansions(expansions)
.add_user_fields(user_fields)
.add_list_fields(list_fields)
.build()
)
response = cast(Response, client.get_owned_lists(**params))
meta = {}
included = {}
list_ids = []
list_names = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
list_ids = [
str(item.id) for item in response.data if hasattr(item, "id")
]
list_names = [
item.name for item in response.data if hasattr(item, "name")
]
return data, included, meta, list_ids, list_names, next_token
raise Exception("User have no owned list")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
list_data, included, meta, list_ids, list_names, next_token = (
self.get_owned_lists(
credentials,
input_data.user_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.user_fields,
input_data.list_fields,
)
)
if list_ids:
yield "list_ids", list_ids
if list_names:
yield "list_names", list_names
if next_token:
yield "next_token", next_token
if list_data:
yield "data", list_data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,527 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
ListExpansionsBuilder,
UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ListExpansionInputs,
ListExpansionsFilter,
ListFieldsFilter,
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterRemoveListMemberBlock(Block):
"""
Removes a member from a Twitter List that the authenticated user owns
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "users.read", "tweet.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to remove the member from",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to remove from the List",
placeholder="Enter user ID to remove",
required=True,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the member was successfully removed"
)
error: str = SchemaField(description="Error message if the removal failed")
def __init__(self):
super().__init__(
id="5a3d1320-a62f-11ef-b7ce-a79e7656bcb0",
description="This block removes a specified user from a Twitter List owned by the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterRemoveListMemberBlock.Input,
output_schema=TwitterRemoveListMemberBlock.Output,
test_input={
"list_id": "123456789",
"user_id": "987654321",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"remove_list_member": lambda *args, **kwargs: True},
)
@staticmethod
def remove_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.remove_list_member(id=list_id, user_id=user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.remove_list_member(
credentials, input_data.list_id, input_data.user_id
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterAddListMemberBlock(Block):
"""
Adds a member to a Twitter List that the authenticated user owns
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "users.read", "tweet.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to add the member to",
placeholder="Enter list ID",
required=True,
)
user_id: str = SchemaField(
description="The ID of the user to add to the List",
placeholder="Enter user ID to add",
required=True,
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the member was successfully added"
)
error: str = SchemaField(description="Error message if the addition failed")
def __init__(self):
super().__init__(
id="3ee8284e-a62f-11ef-84e4-8f6e2cbf0ddb",
description="This block adds a specified user to a Twitter List owned by the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterAddListMemberBlock.Input,
output_schema=TwitterAddListMemberBlock.Output,
test_input={
"list_id": "123456789",
"user_id": "987654321",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"add_list_member": lambda *args, **kwargs: True},
)
@staticmethod
def add_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.add_list_member(id=list_id, user_id=user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.add_list_member(
credentials, input_data.list_id, input_data.user_id
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetListMembersBlock(Block):
"""
Gets the members of a specified Twitter List
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to get members from",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(
description="Maximum number of results per page (1-100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
ids: list[str] = SchemaField(description="List of member user IDs")
usernames: list[str] = SchemaField(description="List of member usernames")
next_token: str = SchemaField(description="Next token for pagination")
data: list[dict] = SchemaField(
description="Complete user data for list members"
)
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="4dba046e-a62f-11ef-b69a-87240c84b4c7",
description="This block retrieves the members of a specified Twitter List.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetListMembersBlock.Input,
output_schema=TwitterGetListMembersBlock.Output,
test_input={
"list_id": "123456789",
"max_results": 2,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["12345", "67890"]),
("usernames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "12345", "username": "testuser1"},
{"id": "67890", "username": "testuser2"},
],
),
],
test_mock={
"get_list_members": lambda *args, **kwargs: (
["12345", "67890"],
["testuser1", "testuser2"],
[
{"id": "12345", "username": "testuser1"},
{"id": "67890", "username": "testuser2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_list_members(
credentials: TwitterCredentials,
list_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": list_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_list_members(**params))
meta = {}
included = {}
next_token = None
user_ids = []
usernames = []
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
user_ids = [str(user.id) for user in response.data]
usernames = [user.username for user in response.data]
return user_ids, usernames, data, included, meta, next_token
raise Exception("List members not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, usernames, data, included, meta, next_token = self.get_list_members(
credentials,
input_data.list_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if usernames:
yield "usernames", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetListMembershipsBlock(Block):
"""
Gets all Lists that a specified user is a member of
"""
class Input(ListExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.read", "offline.access"]
)
user_id: str = SchemaField(
description="The ID of the user whose List memberships to retrieve",
placeholder="Enter user ID",
required=True,
)
max_results: int | None = SchemaField(
description="Maximum number of results per page (1-100)",
placeholder="Enter max results",
advanced=True,
default=10,
)
pagination_token: str | None = SchemaField(
description="Token for pagination of results",
placeholder="Enter pagination token",
advanced=True,
default="",
)
class Output(BlockSchema):
list_ids: list[str] = SchemaField(description="List of list IDs")
next_token: str = SchemaField(description="Next token for pagination")
data: list[dict] = SchemaField(description="List membership data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata about pagination")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="46e6429c-a62f-11ef-81c0-2b55bc7823ba",
description="This block retrieves all Lists that a specified user is a member of.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetListMembershipsBlock.Input,
output_schema=TwitterGetListMembershipsBlock.Output,
test_input={
"user_id": "123456789",
"max_results": 1,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"list_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("list_ids", ["84839422"]),
("data", [{"id": "84839422"}]),
],
test_mock={
"get_list_memberships": lambda *args, **kwargs: (
[{"id": "84839422"}],
{},
{},
["84839422"],
None,
)
},
)
@staticmethod
def get_list_memberships(
credentials: TwitterCredentials,
user_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: ListExpansionsFilter | None,
user_fields: TweetUserFieldsFilter | None,
list_fields: ListFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": user_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
ListExpansionsBuilder(params)
.add_expansions(expansions)
.add_user_fields(user_fields)
.add_list_fields(list_fields)
.build()
)
response = cast(Response, client.get_list_memberships(**params))
meta = {}
included = {}
next_token = None
list_ids = []
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
list_ids = [str(lst.id) for lst in response.data]
return data, included, meta, list_ids, next_token
raise Exception("List memberships not found")
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, meta, list_ids, next_token = self.get_list_memberships(
credentials,
input_data.user_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.user_fields,
input_data.list_fields,
)
if list_ids:
yield "list_ids", list_ids
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,217 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterGetListTweetsBlock(Block):
"""
Gets tweets from a specified Twitter list
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List whose Tweets you would like to retrieve",
placeholder="Enter list ID",
required=True,
)
max_results: int | None = SchemaField(
description="Maximum number of results per page (1-100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for paginating through results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
# Common outputs
tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
texts: list[str] = SchemaField(description="List of tweet texts")
next_token: str = SchemaField(description="Token for next page of results")
# Complete outputs
data: list[dict] = SchemaField(description="Complete list tweets data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(
description="Response metadata including pagination tokens"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="6657edb0-a62f-11ef-8c10-0326d832467d",
description="This block retrieves tweets from a specified Twitter list.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetListTweetsBlock.Input,
output_schema=TwitterGetListTweetsBlock.Output,
test_input={
"list_id": "84839422",
"max_results": 1,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("tweet_ids", ["1234567890"]),
("texts", ["Test tweet"]),
("data", [{"id": "1234567890", "text": "Test tweet"}]),
],
test_mock={
"get_list_tweets": lambda *args, **kwargs: (
[{"id": "1234567890", "text": "Test tweet"}],
{},
{},
["1234567890"],
["Test tweet"],
None,
)
},
)
@staticmethod
def get_list_tweets(
credentials: TwitterCredentials,
list_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": list_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_list_tweets(**params))
meta = {}
included = {}
tweet_ids = []
texts = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
tweet_ids = [str(item.id) for item in response.data]
texts = [item.text for item in response.data]
return data, included, meta, tweet_ids, texts, next_token
raise Exception("No tweets found in this list")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
list_data, included, meta, tweet_ids, texts, next_token = (
self.get_list_tweets(
credentials,
input_data.list_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if tweet_ids:
yield "tweet_ids", tweet_ids
if texts:
yield "texts", texts
if next_token:
yield "next_token", next_token
if list_data:
yield "data", list_data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,278 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterDeleteListBlock(Block):
"""
Deletes a Twitter List owned by the authenticated user
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to be deleted",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the deletion was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="843c6892-a62f-11ef-a5c8-b71239a78d3b",
description="This block deletes a specified Twitter List owned by the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterDeleteListBlock.Input,
output_schema=TwitterDeleteListBlock.Output,
test_input={"list_id": "1234567890", "credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"delete_list": lambda *args, **kwargs: True},
)
@staticmethod
def delete_list(credentials: TwitterCredentials, list_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.delete_list(id=list_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.delete_list(credentials, input_data.list_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterUpdateListBlock(Block):
"""
Updates a Twitter List owned by the authenticated user
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to be updated",
placeholder="Enter list ID",
advanced=False,
)
name: str | None = SchemaField(
description="New name for the List",
placeholder="Enter list name",
default="",
advanced=False,
)
description: str | None = SchemaField(
description="New description for the List",
placeholder="Enter list description",
default="",
advanced=False,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the update was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="7d12630a-a62f-11ef-90c9-8f5a996612c3",
description="This block updates a specified Twitter List owned by the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUpdateListBlock.Input,
output_schema=TwitterUpdateListBlock.Output,
test_input={
"list_id": "1234567890",
"name": "Updated List Name",
"description": "Updated List Description",
"private": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"update_list": lambda *args, **kwargs: True},
)
@staticmethod
def update_list(
credentials: TwitterCredentials,
list_id: str,
name: str | None,
description: str | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.update_list(
id=list_id,
name=None if name == "" else name,
description=None if description == "" else description,
user_auth=False,
)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.update_list(
credentials, input_data.list_id, input_data.name, input_data.description
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterCreateListBlock(Block):
"""
Creates a Twitter List owned by the authenticated user
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "offline.access"]
)
name: str = SchemaField(
description="The name of the List to be created",
placeholder="Enter list name",
advanced=False,
default="",
)
description: str | None = SchemaField(
description="Description of the List",
placeholder="Enter list description",
advanced=False,
default="",
)
private: bool = SchemaField(
description="Whether the List should be private",
advanced=False,
default=False,
)
class Output(BlockSchema):
url: str = SchemaField(description="URL of the created list")
list_id: str = SchemaField(description="ID of the created list")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="724148ba-a62f-11ef-89ba-5349b813ef5f",
description="This block creates a new Twitter List for the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterCreateListBlock.Input,
output_schema=TwitterCreateListBlock.Output,
test_input={
"name": "New List Name",
"description": "New List Description",
"private": True,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("list_id", "1234567890"),
("url", "https://twitter.com/i/lists/1234567890"),
],
test_mock={"create_list": lambda *args, **kwargs: ("1234567890")},
)
@staticmethod
def create_list(
credentials: TwitterCredentials,
name: str,
description: str | None,
private: bool,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
response = cast(
Response,
client.create_list(
name=None if name == "" else name,
description=None if description == "" else description,
private=private,
user_auth=False,
),
)
list_id = str(response.data["id"])
return list_id
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
list_id = self.create_list(
credentials, input_data.name, input_data.description, input_data.private
)
yield "list_id", list_id
yield "url", f"https://twitter.com/i/lists/{list_id}"
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,285 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import ListExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ListExpansionInputs,
ListExpansionsFilter,
ListFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterUnpinListBlock(Block):
"""
Enables the authenticated user to unpin a List.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "users.read", "tweet.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to unpin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the unpin was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="a099c034-a62f-11ef-9622-47d0ceb73555",
description="This block allows the authenticated user to unpin a specified List.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnpinListBlock.Input,
output_schema=TwitterUnpinListBlock.Output,
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"unpin_list": lambda *args, **kwargs: True},
)
@staticmethod
def unpin_list(credentials: TwitterCredentials, list_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unpin_list(list_id=list_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unpin_list(credentials, input_data.list_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterPinListBlock(Block):
"""
Enables the authenticated user to pin a List.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["list.write", "users.read", "tweet.read", "offline.access"]
)
list_id: str = SchemaField(
description="The ID of the List to pin",
placeholder="Enter list ID",
required=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the pin was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="8ec16e48-a62f-11ef-9f35-f3d6de43a802",
description="This block allows the authenticated user to pin a specified List.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterPinListBlock.Input,
output_schema=TwitterPinListBlock.Output,
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"pin_list": lambda *args, **kwargs: True},
)
@staticmethod
def pin_list(credentials: TwitterCredentials, list_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.pin_list(list_id=list_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.pin_list(credentials, input_data.list_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetPinnedListsBlock(Block):
"""
Returns the Lists pinned by the authenticated user.
"""
class Input(ListExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["lists.read", "users.read", "offline.access"]
)
class Output(BlockSchema):
list_ids: list[str] = SchemaField(description="List IDs of the pinned lists")
list_names: list[str] = SchemaField(
description="List names of the pinned lists"
)
data: list[dict] = SchemaField(
description="Response data containing pinned lists"
)
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata about the response")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="97e03aae-a62f-11ef-bc53-5b89cb02888f",
description="This block returns the Lists pinned by the authenticated user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetPinnedListsBlock.Input,
output_schema=TwitterGetPinnedListsBlock.Output,
test_input={
"expansions": None,
"list_fields": None,
"user_fields": None,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("list_ids", ["84839422"]),
("list_names", ["Twitter List"]),
("data", [{"id": "84839422", "name": "Twitter List"}]),
],
test_mock={
"get_pinned_lists": lambda *args, **kwargs: (
[{"id": "84839422", "name": "Twitter List"}],
{},
{},
["84839422"],
["Twitter List"],
)
},
)
@staticmethod
def get_pinned_lists(
credentials: TwitterCredentials,
expansions: ListExpansionsFilter | None,
user_fields: TweetUserFieldsFilter | None,
list_fields: ListFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {"user_auth": False}
params = (
ListExpansionsBuilder(params)
.add_expansions(expansions)
.add_user_fields(user_fields)
.add_list_fields(list_fields)
.build()
)
response = cast(Response, client.get_pinned_lists(**params))
meta = {}
included = {}
list_ids = []
list_names = []
if response.meta:
meta = response.meta
if response.includes:
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
list_ids = [str(item.id) for item in response.data]
list_names = [item.name for item in response.data]
return data, included, meta, list_ids, list_names
raise Exception("Lists not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
list_data, included, meta, list_ids, list_names = self.get_pinned_lists(
credentials,
input_data.expansions,
input_data.user_fields,
input_data.list_fields,
)
if list_ids:
yield "list_ids", list_ids
if list_names:
yield "list_names", list_names
if list_data:
yield "data", list_data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,195 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import SpaceExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
SpaceExpansionInputs,
SpaceExpansionsFilter,
SpaceFieldsFilter,
SpaceStatesFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterSearchSpacesBlock(Block):
"""
Returns live or scheduled Spaces matching specified search terms [for a week only]
"""
class Input(SpaceExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["spaces.read", "users.read", "tweet.read", "offline.access"]
)
query: str = SchemaField(
description="Search term to find in Space titles",
placeholder="Enter search query",
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
state: SpaceStatesFilter = SchemaField(
description="Type of Spaces to return (live, scheduled, or all)",
placeholder="Enter state filter",
default=SpaceStatesFilter.all,
)
class Output(BlockSchema):
# Common outputs that user commonly uses
ids: list[str] = SchemaField(description="List of space IDs")
titles: list[str] = SchemaField(description="List of space titles")
host_ids: list = SchemaField(description="List of host IDs")
next_token: str = SchemaField(description="Next token for pagination")
# Complete outputs for advanced use
data: list[dict] = SchemaField(description="Complete space data")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="aaefdd48-a62f-11ef-a73c-3f44df63e276",
description="This block searches for Twitter Spaces based on specified terms.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterSearchSpacesBlock.Input,
output_schema=TwitterSearchSpacesBlock.Output,
test_input={
"query": "tech",
"max_results": 1,
"state": "live",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"space_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1234"]),
("titles", ["Tech Talk"]),
("host_ids", ["5678"]),
("data", [{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}]),
],
test_mock={
"search_spaces": lambda *args, **kwargs: (
[{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}],
{},
{},
["1234"],
["Tech Talk"],
["5678"],
None,
)
},
)
@staticmethod
def search_spaces(
credentials: TwitterCredentials,
query: str,
max_results: int | None,
state: SpaceStatesFilter,
expansions: SpaceExpansionsFilter | None,
space_fields: SpaceFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {"query": query, "max_results": max_results, "state": state.value}
params = (
SpaceExpansionsBuilder(params)
.add_expansions(expansions)
.add_space_fields(space_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.search_spaces(**params))
meta = {}
next_token = ""
if response.meta:
meta = response.meta
if "next_token" in meta:
next_token = meta["next_token"]
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
ids = [str(space["id"]) for space in response.data if "id" in space]
titles = [space["title"] for space in data if "title" in space]
host_ids = [space["host_ids"] for space in data if "host_ids" in space]
return data, included, meta, ids, titles, host_ids, next_token
raise Exception("Spaces not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, meta, ids, titles, host_ids, next_token = (
self.search_spaces(
credentials,
input_data.query,
input_data.max_results,
input_data.state,
input_data.expansions,
input_data.space_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if titles:
yield "titles", titles
if host_ids:
yield "host_ids", host_ids
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "includes", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,651 @@
from typing import Literal, Union, cast
import tweepy
from pydantic import BaseModel
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
SpaceExpansionsBuilder,
TweetExpansionsBuilder,
UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
SpaceExpansionInputs,
SpaceExpansionsFilter,
SpaceFieldsFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class SpaceList(BaseModel):
discriminator: Literal["space_list"]
space_ids: list[str] = SchemaField(
description="List of Space IDs to lookup (up to 100)",
placeholder="Enter Space IDs",
default=[],
advanced=False,
)
class UserList(BaseModel):
discriminator: Literal["user_list"]
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup their Spaces (up to 100)",
placeholder="Enter user IDs",
default=[],
advanced=False,
)
class TwitterGetSpacesBlock(Block):
"""
Gets information about multiple Twitter Spaces specified by Space IDs or creator user IDs
"""
class Input(SpaceExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["spaces.read", "users.read", "offline.access"]
)
identifier: Union[SpaceList, UserList] = SchemaField(
discriminator="discriminator",
description="Choose whether to lookup spaces by their IDs or by creator user IDs",
advanced=False,
)
class Output(BlockSchema):
# Common outputs
ids: list[str] = SchemaField(description="List of space IDs")
titles: list[str] = SchemaField(description="List of space titles")
# Complete outputs for advanced use
data: list[dict] = SchemaField(description="Complete space data")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="d75bd7d8-a62f-11ef-b0d8-c7a9496f617f",
description="This block retrieves information about multiple Twitter Spaces.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetSpacesBlock.Input,
output_schema=TwitterGetSpacesBlock.Output,
test_input={
"identifier": {
"discriminator": "space_list",
"space_ids": ["1DXxyRYNejbKM"],
},
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"space_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1DXxyRYNejbKM"]),
("titles", ["Test Space"]),
(
"data",
[
{
"id": "1DXxyRYNejbKM",
"title": "Test Space",
"host_id": "1234567",
}
],
),
],
test_mock={
"get_spaces": lambda *args, **kwargs: (
[
{
"id": "1DXxyRYNejbKM",
"title": "Test Space",
"host_id": "1234567",
}
],
{},
["1DXxyRYNejbKM"],
["Test Space"],
)
},
)
@staticmethod
def get_spaces(
credentials: TwitterCredentials,
identifier: Union[SpaceList, UserList],
expansions: SpaceExpansionsFilter | None,
space_fields: SpaceFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"ids": (
identifier.space_ids if isinstance(identifier, SpaceList) else None
),
"user_ids": (
identifier.user_ids if isinstance(identifier, UserList) else None
),
}
params = (
SpaceExpansionsBuilder(params)
.add_expansions(expansions)
.add_space_fields(space_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_spaces(**params))
ids = []
titles = []
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
ids = [space["id"] for space in data if "id" in space]
titles = [space["title"] for space in data if "title" in space]
return data, included, ids, titles
raise Exception("No spaces found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, ids, titles = self.get_spaces(
credentials,
input_data.identifier,
input_data.expansions,
input_data.space_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if titles:
yield "titles", titles
if data:
yield "data", data
if included:
yield "includes", included
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetSpaceByIdBlock(Block):
"""
Gets information about a single Twitter Space specified by Space ID
"""
class Input(SpaceExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["spaces.read", "users.read", "offline.access"]
)
space_id: str = SchemaField(
description="Space ID to lookup",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
# Common outputs
id: str = SchemaField(description="Space ID")
title: str = SchemaField(description="Space title")
host_ids: list[str] = SchemaField(description="Host ID")
# Complete outputs for advanced use
data: dict = SchemaField(description="Complete space data")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="c79700de-a62f-11ef-ab20-fb32bf9d5a9d",
description="This block retrieves information about a single Twitter Space.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetSpaceByIdBlock.Input,
output_schema=TwitterGetSpaceByIdBlock.Output,
test_input={
"space_id": "1DXxyRYNejbKM",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"space_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", "1DXxyRYNejbKM"),
("title", "Test Space"),
("host_ids", ["1234567"]),
(
"data",
{
"id": "1DXxyRYNejbKM",
"title": "Test Space",
"host_ids": ["1234567"],
},
),
],
test_mock={
"get_space": lambda *args, **kwargs: (
{
"id": "1DXxyRYNejbKM",
"title": "Test Space",
"host_ids": ["1234567"],
},
{},
)
},
)
@staticmethod
def get_space(
credentials: TwitterCredentials,
space_id: str,
expansions: SpaceExpansionsFilter | None,
space_fields: SpaceFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": space_id,
}
params = (
SpaceExpansionsBuilder(params)
.add_expansions(expansions)
.add_space_fields(space_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_space(**params))
includes = {}
if response.includes:
for key, value in response.includes.items():
if isinstance(value, list):
includes[key] = [
item.data if hasattr(item, "data") else item
for item in value
]
else:
includes[key] = value.data if hasattr(value, "data") else value
data = {}
if response.data:
for key, value in response.data.items():
if isinstance(value, list):
data[key] = [
item.data if hasattr(item, "data") else item
for item in value
]
else:
data[key] = value.data if hasattr(value, "data") else value
return data, includes
raise Exception("Space not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
space_data, includes = self.get_space(
credentials,
input_data.space_id,
input_data.expansions,
input_data.space_fields,
input_data.user_fields,
)
# Common outputs
if space_data:
if "id" in space_data:
yield "id", space_data.get("id")
if "title" in space_data:
yield "title", space_data.get("title")
if "host_ids" in space_data:
yield "host_ids", space_data.get("host_ids")
if space_data:
yield "data", space_data
if includes:
yield "includes", includes
except Exception as e:
yield "error", handle_tweepy_exception(e)
# Not tested yet, might have some problem
class TwitterGetSpaceBuyersBlock(Block):
"""
Gets list of users who purchased a ticket to the requested Space
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["spaces.read", "users.read", "offline.access"]
)
space_id: str = SchemaField(
description="Space ID to lookup buyers for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
# Common outputs
buyer_ids: list[str] = SchemaField(description="List of buyer IDs")
usernames: list[str] = SchemaField(description="List of buyer usernames")
# Complete outputs for advanced use
data: list[dict] = SchemaField(description="Complete space buyers data")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="c1c121a8-a62f-11ef-8b0e-d7b85f96a46f",
description="This block retrieves a list of users who purchased tickets to a Twitter Space.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetSpaceBuyersBlock.Input,
output_schema=TwitterGetSpaceBuyersBlock.Output,
test_input={
"space_id": "1DXxyRYNejbKM",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("buyer_ids", ["2244994945"]),
("usernames", ["testuser"]),
(
"data",
[{"id": "2244994945", "username": "testuser", "name": "Test User"}],
),
],
test_mock={
"get_space_buyers": lambda *args, **kwargs: (
[{"id": "2244994945", "username": "testuser", "name": "Test User"}],
{},
["2244994945"],
["testuser"],
)
},
)
@staticmethod
def get_space_buyers(
credentials: TwitterCredentials,
space_id: str,
expansions: UserExpansionsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": space_id,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_space_buyers(**params))
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
buyer_ids = [buyer["id"] for buyer in data]
usernames = [buyer["username"] for buyer in data]
return data, included, buyer_ids, usernames
raise Exception("No buyers found for this Space")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
buyers_data, included, buyer_ids, usernames = self.get_space_buyers(
credentials,
input_data.space_id,
input_data.expansions,
input_data.user_fields,
)
if buyer_ids:
yield "buyer_ids", buyer_ids
if usernames:
yield "usernames", usernames
if buyers_data:
yield "data", buyers_data
if included:
yield "includes", included
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetSpaceTweetsBlock(Block):
"""
Gets list of Tweets shared in the requested Space
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["spaces.read", "users.read", "offline.access"]
)
space_id: str = SchemaField(
description="Space ID to lookup tweets for",
placeholder="Enter Space ID",
required=True,
)
class Output(BlockSchema):
# Common outputs
tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
texts: list[str] = SchemaField(description="List of tweet texts")
# Complete outputs for advanced use
data: list[dict] = SchemaField(description="Complete space tweets data")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Response metadata")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="b69731e6-a62f-11ef-b2d4-1bf14dd6aee4",
description="This block retrieves tweets shared in a Twitter Space.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetSpaceTweetsBlock.Input,
output_schema=TwitterGetSpaceTweetsBlock.Output,
test_input={
"space_id": "1DXxyRYNejbKM",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("tweet_ids", ["1234567890"]),
("texts", ["Test tweet"]),
("data", [{"id": "1234567890", "text": "Test tweet"}]),
],
test_mock={
"get_space_tweets": lambda *args, **kwargs: (
[{"id": "1234567890", "text": "Test tweet"}], # data
{},
["1234567890"],
["Test tweet"],
{},
)
},
)
@staticmethod
def get_space_tweets(
credentials: TwitterCredentials,
space_id: str,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": space_id,
}
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_space_tweets(**params))
included = IncludesSerializer.serialize(response.includes)
if response.data:
data = ResponseDataSerializer.serialize_list(response.data)
tweet_ids = [str(tweet["id"]) for tweet in data]
texts = [tweet["text"] for tweet in data]
meta = response.meta or {}
return data, included, tweet_ids, texts, meta
raise Exception("No tweets found for this Space")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
tweets_data, included, tweet_ids, texts, meta = self.get_space_tweets(
credentials,
input_data.space_id,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
if tweet_ids:
yield "tweet_ids", tweet_ids
if texts:
yield "texts", texts
if tweets_data:
yield "data", tweets_data
if included:
yield "includes", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,20 @@
import tweepy
def handle_tweepy_exception(e: Exception) -> str:
if isinstance(e, tweepy.BadRequest):
return f"Bad Request (400): {str(e)}"
elif isinstance(e, tweepy.Unauthorized):
return f"Unauthorized (401): {str(e)}"
elif isinstance(e, tweepy.Forbidden):
return f"Forbidden (403): {str(e)}"
elif isinstance(e, tweepy.NotFound):
return f"Not Found (404): {str(e)}"
elif isinstance(e, tweepy.TooManyRequests):
return f"Too Many Requests (429): {str(e)}"
elif isinstance(e, tweepy.TwitterServerError):
return f"Twitter Server Error (5xx): {str(e)}"
elif isinstance(e, tweepy.TweepyException):
return f"Tweepy Error: {str(e)}"
else:
return f"Unexpected error: {str(e)}"

View File

@ -0,0 +1,372 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterBookmarkTweetBlock(Block):
"""
Bookmark a tweet on Twitter
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "bookmark.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to bookmark",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the bookmark was successful")
error: str = SchemaField(description="Error message if the bookmark failed")
def __init__(self):
super().__init__(
id="f33d67be-a62f-11ef-a797-ff83ec29ee8e",
description="This block bookmarks a tweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterBookmarkTweetBlock.Input,
output_schema=TwitterBookmarkTweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"bookmark_tweet": lambda *args, **kwargs: True},
)
@staticmethod
def bookmark_tweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.bookmark(tweet_id)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.bookmark_tweet(credentials, input_data.tweet_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetBookmarkedTweetsBlock(Block):
"""
Get All your bookmarked tweets from Twitter
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "bookmark.read", "users.read", "offline.access"]
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
id: list[str] = SchemaField(description="All Tweet IDs")
text: list[str] = SchemaField(description="All Tweet texts")
userId: list[str] = SchemaField(description="IDs of the tweet authors")
userName: list[str] = SchemaField(description="Usernames of the tweet authors")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
next_token: str = SchemaField(description="Next token for pagination")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="ed26783e-a62f-11ef-9a21-c77c57dd8a1f",
description="This block retrieves bookmarked tweets from Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetBookmarkedTweetsBlock.Input,
output_schema=TwitterGetBookmarkedTweetsBlock.Output,
test_input={
"max_results": 2,
"pagination_token": None,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", ["1234567890"]),
("text", ["Test tweet"]),
("userId", ["12345"]),
("userName", ["testuser"]),
("data", [{"id": "1234567890", "text": "Test tweet"}]),
],
test_mock={
"get_bookmarked_tweets": lambda *args, **kwargs: (
["1234567890"],
["Test tweet"],
["12345"],
["testuser"],
[{"id": "1234567890", "text": "Test tweet"}],
{},
{},
None,
)
},
)
@staticmethod
def get_bookmarked_tweets(
credentials: TwitterCredentials,
max_results: int | None,
pagination_token: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
}
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(
Response,
client.get_bookmarks(**params),
)
meta = {}
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if "users" in included:
for user in included["users"]:
user_ids.append(str(user["id"]))
user_names.append(user["username"])
return (
tweet_ids,
tweet_texts,
user_ids,
user_names,
data,
included,
meta,
next_token,
)
raise Exception("No bookmarked tweets found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta, next_token = (
self.get_bookmarked_tweets(
credentials,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "id", ids
if texts:
yield "text", texts
if user_ids:
yield "userId", user_ids
if user_names:
yield "userName", user_names
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
if next_token:
yield "next_token", next_token
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterRemoveBookmarkTweetBlock(Block):
"""
Remove a bookmark for a tweet on Twitter
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "bookmark.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to remove bookmark from",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the bookmark was successfully removed"
)
error: str = SchemaField(
description="Error message if the bookmark removal failed"
)
def __init__(self):
super().__init__(
id="e4100684-a62f-11ef-9be9-770cb41a2616",
description="This block removes a bookmark from a tweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterRemoveBookmarkTweetBlock.Input,
output_schema=TwitterRemoveBookmarkTweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"remove_bookmark_tweet": lambda *args, **kwargs: True},
)
@staticmethod
def remove_bookmark_tweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.remove_bookmark(tweet_id)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.remove_bookmark_tweet(credentials, input_data.tweet_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,154 @@
import tweepy
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterHideReplyBlock(Block):
"""
Hides a reply of one of your tweets
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet reply to hide",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="07d58b3e-a630-11ef-a030-93701d1a465e",
description="This block hides a reply to a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterHideReplyBlock.Input,
output_schema=TwitterHideReplyBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"hide_reply": lambda *args, **kwargs: True},
)
@staticmethod
def hide_reply(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.hide_reply(id=tweet_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.hide_reply(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterUnhideReplyBlock(Block):
"""
Unhides a reply to a tweet
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet reply to unhide",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="fcf9e4e4-a62f-11ef-9d85-57d3d06b616a",
description="This block unhides a reply to a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnhideReplyBlock.Input,
output_schema=TwitterUnhideReplyBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unhide_reply": lambda *args, **kwargs: True},
)
@staticmethod
def unhide_reply(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unhide_reply(id=tweet_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unhide_reply(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,576 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
TweetExpansionsBuilder,
UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterLikeTweetBlock(Block):
"""
Likes a tweet
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "like.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to like",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="4d0b4c5c-a630-11ef-8e08-1b14c507b347",
description="This block likes a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterLikeTweetBlock.Input,
output_schema=TwitterLikeTweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"like_tweet": lambda *args, **kwargs: True},
)
@staticmethod
def like_tweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.like(tweet_id=tweet_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.like_tweet(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetLikingUsersBlock(Block):
"""
Gets information about users who liked a one of your tweet
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "like.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to get liking users for",
placeholder="Enter tweet ID",
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for getting next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
id: list[str] = SchemaField(description="All User IDs who liked the tweet")
username: list[str] = SchemaField(
description="All User usernames who liked the tweet"
)
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="34275000-a630-11ef-b01e-5f00d9077c08",
description="This block gets information about users who liked a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetLikingUsersBlock.Input,
output_schema=TwitterGetLikingUsersBlock.Output,
test_input={
"tweet_id": "1234567890",
"max_results": 1,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", ["1234567890"]),
("username", ["testuser"]),
("data", [{"id": "1234567890", "username": "testuser"}]),
],
test_mock={
"get_liking_users": lambda *args, **kwargs: (
["1234567890"],
["testuser"],
[{"id": "1234567890", "username": "testuser"}],
{},
{},
None,
)
},
)
@staticmethod
def get_liking_users(
credentials: TwitterCredentials,
tweet_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": tweet_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_liking_users(**params))
if not response.data and not response.meta:
raise Exception("No liking users found")
meta = {}
user_ids = []
usernames = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
user_ids = [str(user.id) for user in response.data]
usernames = [user.username for user in response.data]
return user_ids, usernames, data, included, meta, next_token
raise Exception("No liking users found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, usernames, data, included, meta, next_token = self.get_liking_users(
credentials,
input_data.tweet_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "id", ids
if usernames:
yield "username", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetLikedTweetsBlock(Block):
"""
Gets information about tweets liked by you
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "like.read", "offline.access"]
)
user_id: str = SchemaField(
description="ID of the user to get liked tweets for",
placeholder="Enter user ID",
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (5-100)",
placeholder="100",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for getting next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list[str] = SchemaField(description="All Tweet IDs")
texts: list[str] = SchemaField(description="All Tweet texts")
userIds: list[str] = SchemaField(
description="List of user ids that authored the tweets"
)
userNames: list[str] = SchemaField(
description="List of user names that authored the tweets"
)
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="292e7c78-a630-11ef-9f40-df5dffaca106",
description="This block gets information about tweets liked by a user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetLikedTweetsBlock.Input,
output_schema=TwitterGetLikedTweetsBlock.Output,
test_input={
"user_id": "1234567890",
"max_results": 2,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["12345", "67890"]),
("texts", ["Tweet 1", "Tweet 2"]),
("userIds", ["67890", "67891"]),
("userNames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "12345", "text": "Tweet 1"},
{"id": "67890", "text": "Tweet 2"},
],
),
],
test_mock={
"get_liked_tweets": lambda *args, **kwargs: (
["12345", "67890"],
["Tweet 1", "Tweet 2"],
["67890", "67891"],
["testuser1", "testuser2"],
[
{"id": "12345", "text": "Tweet 1"},
{"id": "67890", "text": "Tweet 2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_liked_tweets(
credentials: TwitterCredentials,
user_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": user_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_liked_tweets(**params))
if not response.data and not response.meta:
raise Exception("No liked tweets found")
meta = {}
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if "users" in response.includes:
user_ids = [str(user["id"]) for user in response.includes["users"]]
user_names = [
user["username"] for user in response.includes["users"]
]
return (
tweet_ids,
tweet_texts,
user_ids,
user_names,
data,
included,
meta,
next_token,
)
raise Exception("No liked tweets found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta, next_token = (
self.get_liked_tweets(
credentials,
input_data.user_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if user_ids:
yield "userIds", user_ids
if user_names:
yield "userNames", user_names
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterUnlikeTweetBlock(Block):
"""
Unlikes a tweet that was previously liked
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "like.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to unlike",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the operation was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="1ed5eab8-a630-11ef-8e21-cbbbc80cbb85",
description="This block unlikes a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnlikeTweetBlock.Input,
output_schema=TwitterUnlikeTweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unlike_tweet": lambda *args, **kwargs: True},
)
@staticmethod
def unlike_tweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unlike(tweet_id=tweet_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unlike_tweet(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,545 @@
from datetime import datetime
from typing import List, Literal, Optional, Union, cast
import tweepy
from pydantic import BaseModel
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
TweetDurationBuilder,
TweetExpansionsBuilder,
TweetPostBuilder,
TweetSearchBuilder,
)
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetReplySettingsFilter,
TweetTimeWindowInputs,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class Media(BaseModel):
discriminator: Literal["media"]
media_ids: Optional[List[str]] = None
media_tagged_user_ids: Optional[List[str]] = None
class DeepLink(BaseModel):
discriminator: Literal["deep_link"]
direct_message_deep_link: Optional[str] = None
class Poll(BaseModel):
discriminator: Literal["poll"]
poll_options: Optional[List[str]] = None
poll_duration_minutes: Optional[int] = None
class Place(BaseModel):
discriminator: Literal["place"]
place_id: Optional[str] = None
class Quote(BaseModel):
discriminator: Literal["quote"]
quote_tweet_id: Optional[str] = None
class TwitterPostTweetBlock(Block):
"""
Create a tweet on Twitter with the option to include one additional element such as a media, quote, or deep link.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.write", "users.read", "offline.access"]
)
tweet_text: str | None = SchemaField(
description="Text of the tweet to post",
placeholder="Enter your tweet",
default=None,
advanced=False,
)
for_super_followers_only: bool = SchemaField(
description="Tweet exclusively for Super Followers",
placeholder="Enter for super followers only",
advanced=True,
default=False,
)
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None = SchemaField(
discriminator="discriminator",
description="Additional tweet data (media, deep link, poll, place or quote)",
advanced=True,
)
exclude_reply_user_ids: Optional[List[str]] = SchemaField(
description="User IDs to exclude from reply Tweet thread. [ex - 6253282]",
placeholder="Enter user IDs to exclude",
advanced=True,
default=None,
)
in_reply_to_tweet_id: Optional[str] = SchemaField(
description="Tweet ID being replied to. Please note that in_reply_to_tweet_id needs to be in the request if exclude_reply_user_ids is present",
default=None,
placeholder="Enter in reply to tweet ID",
advanced=True,
)
reply_settings: TweetReplySettingsFilter = SchemaField(
description="Who can reply to the Tweet (mentionedUsers or following)",
placeholder="Enter reply settings",
advanced=True,
default=TweetReplySettingsFilter(All_Users=True),
)
class Output(BlockSchema):
tweet_id: str = SchemaField(description="ID of the created tweet")
tweet_url: str = SchemaField(description="URL to the tweet")
error: str = SchemaField(
description="Error message if the tweet posting failed"
)
def __init__(self):
super().__init__(
id="7bb0048a-a630-11ef-aeb8-abc0dadb9b12",
description="This block posts a tweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterPostTweetBlock.Input,
output_schema=TwitterPostTweetBlock.Output,
test_input={
"tweet_text": "This is a test tweet.",
"credentials": TEST_CREDENTIALS_INPUT,
"attachment": {
"discriminator": "deep_link",
"direct_message_deep_link": "https://twitter.com/messages/compose",
},
"for_super_followers_only": False,
"exclude_reply_user_ids": [],
"in_reply_to_tweet_id": "",
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("tweet_id", "1234567890"),
("tweet_url", "https://twitter.com/user/status/1234567890"),
],
test_mock={
"post_tweet": lambda *args, **kwargs: (
"1234567890",
"https://twitter.com/user/status/1234567890",
)
},
)
def post_tweet(
self,
credentials: TwitterCredentials,
input_txt: str | None,
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None,
for_super_followers_only: bool,
exclude_reply_user_ids: Optional[List[str]],
in_reply_to_tweet_id: Optional[str],
reply_settings: TweetReplySettingsFilter,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = (
TweetPostBuilder()
.add_text(input_txt)
.add_super_followers(for_super_followers_only)
.add_reply_settings(
exclude_reply_user_ids or [],
in_reply_to_tweet_id or "",
reply_settings,
)
)
if isinstance(attachment, Media):
params.add_media(
attachment.media_ids or [], attachment.media_tagged_user_ids or []
)
elif isinstance(attachment, DeepLink):
params.add_deep_link(attachment.direct_message_deep_link or "")
elif isinstance(attachment, Poll):
params.add_poll_options(attachment.poll_options or [])
params.add_poll_duration(attachment.poll_duration_minutes or 0)
elif isinstance(attachment, Place):
params.add_place(attachment.place_id or "")
elif isinstance(attachment, Quote):
params.add_quote(attachment.quote_tweet_id or "")
tweet = cast(Response, client.create_tweet(**params.build()))
if not tweet.data:
raise Exception("Failed to create tweet")
tweet_id = tweet.data["id"]
tweet_url = f"https://twitter.com/user/status/{tweet_id}"
return str(tweet_id), tweet_url
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
tweet_id, tweet_url = self.post_tweet(
credentials,
input_data.tweet_text,
input_data.attachment,
input_data.for_super_followers_only,
input_data.exclude_reply_user_ids,
input_data.in_reply_to_tweet_id,
input_data.reply_settings,
)
yield "tweet_id", tweet_id
yield "tweet_url", tweet_url
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterDeleteTweetBlock(Block):
"""
Deletes a tweet on Twitter using twitter Id
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to delete",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the tweet was successfully deleted"
)
error: str = SchemaField(
description="Error message if the tweet deletion failed"
)
def __init__(self):
super().__init__(
id="761babf0-a630-11ef-a03d-abceb082f58f",
description="This block deletes a tweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterDeleteTweetBlock.Input,
output_schema=TwitterDeleteTweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"delete_tweet": lambda *args, **kwargs: True},
)
@staticmethod
def delete_tweet(credentials: TwitterCredentials, tweet_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.delete_tweet(id=tweet_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
except Exception:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.delete_tweet(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterSearchRecentTweetsBlock(Block):
"""
Searches all public Tweets in Twitter history
"""
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
query: str = SchemaField(
description="Search query (up to 1024 characters)",
placeholder="Enter search query",
)
max_results: int = SchemaField(
description="Maximum number of results per page (10-500)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination: str | None = SchemaField(
description="Token for pagination",
default="",
placeholder="Enter pagination token",
advanced=True,
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
tweet_ids: list[str] = SchemaField(description="All Tweet IDs")
tweet_texts: list[str] = SchemaField(description="All Tweet texts")
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="53e5cf8e-a630-11ef-ba85-df6d666fa5d5",
description="This block searches all public Tweets in Twitter history.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterSearchRecentTweetsBlock.Input,
output_schema=TwitterSearchRecentTweetsBlock.Output,
test_input={
"query": "from:twitterapi #twitterapi",
"credentials": TEST_CREDENTIALS_INPUT,
"max_results": 2,
"start_time": "2024-12-14T18:30:00.000Z",
"end_time": "2024-12-17T18:30:00.000Z",
"since_id": None,
"until_id": None,
"sort_order": None,
"pagination": None,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("tweet_ids", ["1373001119480344583", "1372627771717869568"]),
(
"tweet_texts",
[
"Looking to get started with the Twitter API but new to APIs in general?",
"Thanks to everyone who joined and made today a great session!",
],
),
(
"data",
[
{
"id": "1373001119480344583",
"text": "Looking to get started with the Twitter API but new to APIs in general?",
},
{
"id": "1372627771717869568",
"text": "Thanks to everyone who joined and made today a great session!",
},
],
),
],
test_mock={
"search_tweets": lambda *args, **kwargs: (
["1373001119480344583", "1372627771717869568"],
[
"Looking to get started with the Twitter API but new to APIs in general?",
"Thanks to everyone who joined and made today a great session!",
],
[
{
"id": "1373001119480344583",
"text": "Looking to get started with the Twitter API but new to APIs in general?",
},
{
"id": "1372627771717869568",
"text": "Thanks to everyone who joined and made today a great session!",
},
],
{},
{},
None,
)
},
)
@staticmethod
def search_tweets(
credentials: TwitterCredentials,
query: str,
max_results: int,
start_time: datetime | None,
end_time: datetime | None,
since_id: str | None,
until_id: str | None,
sort_order: str | None,
pagination: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
# Building common params
params = (
TweetSearchBuilder()
.add_query(query)
.add_pagination(max_results, pagination)
.build()
)
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
# Adding time window to params If required by the user
params = (
TweetDurationBuilder(params)
.add_start_time(start_time)
.add_end_time(end_time)
.add_since_id(since_id)
.add_until_id(until_id)
.add_sort_order(sort_order)
.build()
)
response = cast(Response, client.search_recent_tweets(**params))
if not response.data and not response.meta:
raise Exception("No tweets found")
meta = {}
tweet_ids = []
tweet_texts = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
return tweet_ids, tweet_texts, data, included, meta, next_token
raise Exception("No tweets found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, data, included, meta, next_token = self.search_tweets(
credentials,
input_data.query,
input_data.max_results,
input_data.start_time,
input_data.end_time,
input_data.since_id,
input_data.until_id,
input_data.sort_order,
input_data.pagination,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "tweet_ids", ids
if texts:
yield "tweet_texts", texts
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,222 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExcludesFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterGetQuoteTweetsBlock(Block):
"""
Gets quote tweets for a specified tweet ID
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to get quotes for",
placeholder="Enter tweet ID",
)
max_results: int | None = SchemaField(
description="Number of results to return (max 100)",
default=10,
advanced=True,
)
exclude: TweetExcludesFilter | None = SchemaField(
description="Types of tweets to exclude", advanced=True, default=None
)
pagination_token: str | None = SchemaField(
description="Token for pagination",
advanced=True,
default="",
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list = SchemaField(description="All Tweet IDs ")
texts: list = SchemaField(description="All Tweet texts")
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="9fbdd208-a630-11ef-9b97-ab7a3a695ca3",
description="This block gets quote tweets for a specific tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetQuoteTweetsBlock.Input,
output_schema=TwitterGetQuoteTweetsBlock.Output,
test_input={
"tweet_id": "1234567890",
"max_results": 2,
"pagination_token": None,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["12345", "67890"]),
("texts", ["Tweet 1", "Tweet 2"]),
(
"data",
[
{"id": "12345", "text": "Tweet 1"},
{"id": "67890", "text": "Tweet 2"},
],
),
],
test_mock={
"get_quote_tweets": lambda *args, **kwargs: (
["12345", "67890"],
["Tweet 1", "Tweet 2"],
[
{"id": "12345", "text": "Tweet 1"},
{"id": "67890", "text": "Tweet 2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_quote_tweets(
credentials: TwitterCredentials,
tweet_id: str,
max_results: int | None,
exclude: TweetExcludesFilter | None,
pagination_token: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": tweet_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"exclude": None if exclude == TweetExcludesFilter() else exclude,
"user_auth": False,
}
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_quote_tweets(**params))
meta = {}
tweet_ids = []
tweet_texts = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
return tweet_ids, tweet_texts, data, included, meta, next_token
raise Exception("No quote tweets found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, data, included, meta, next_token = self.get_quote_tweets(
credentials,
input_data.tweet_id,
input_data.max_results,
input_data.exclude,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,363 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterRetweetBlock(Block):
"""
Retweets a tweet on Twitter
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to retweet",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the retweet was successful")
error: str = SchemaField(description="Error message if the retweet failed")
def __init__(self):
super().__init__(
id="bd7b8d3a-a630-11ef-be96-6f4aa4c3c4f4",
description="This block retweets a tweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterRetweetBlock.Input,
output_schema=TwitterRetweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"retweet": lambda *args, **kwargs: True},
)
@staticmethod
def retweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.retweet(
tweet_id=tweet_id,
user_auth=False,
)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.retweet(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterRemoveRetweetBlock(Block):
"""
Removes a retweet on Twitter
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "tweet.write", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to remove retweet",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the retweet was successfully removed"
)
error: str = SchemaField(description="Error message if the removal failed")
def __init__(self):
super().__init__(
id="b6e663f0-a630-11ef-a7f0-8b9b0c542ff8",
description="This block removes a retweet on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterRemoveRetweetBlock.Input,
output_schema=TwitterRemoveRetweetBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"remove_retweet": lambda *args, **kwargs: True},
)
@staticmethod
def remove_retweet(
credentials: TwitterCredentials,
tweet_id: str,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unretweet(
source_tweet_id=tweet_id,
user_auth=False,
)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.remove_retweet(
credentials,
input_data.tweet_id,
)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetRetweetersBlock(Block):
"""
Gets information about who has retweeted a tweet
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="ID of the tweet to get retweeters for",
placeholder="Enter tweet ID",
)
max_results: int | None = SchemaField(
description="Maximum number of results per page (1-100)",
default=10,
placeholder="Enter max results",
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination",
placeholder="Enter pagination token",
default="",
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list = SchemaField(description="List of user ids who retweeted")
names: list = SchemaField(description="List of user names who retweeted")
usernames: list = SchemaField(
description="List of user usernames who retweeted"
)
next_token: str = SchemaField(description="Token for next page of results")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="ad7aa6fa-a630-11ef-a6b0-e7ca640aa030",
description="This block gets information about who has retweeted a tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetRetweetersBlock.Input,
output_schema=TwitterGetRetweetersBlock.Output,
test_input={
"tweet_id": "1234567890",
"credentials": TEST_CREDENTIALS_INPUT,
"max_results": 1,
"pagination_token": "",
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["12345"]),
("names", ["Test User"]),
("usernames", ["testuser"]),
(
"data",
[{"id": "12345", "name": "Test User", "username": "testuser"}],
),
],
test_mock={
"get_retweeters": lambda *args, **kwargs: (
[{"id": "12345", "name": "Test User", "username": "testuser"}],
{},
{},
["12345"],
["Test User"],
["testuser"],
None,
)
},
)
@staticmethod
def get_retweeters(
credentials: TwitterCredentials,
tweet_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": tweet_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_retweeters(**params))
meta = {}
ids = []
names = []
usernames = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
ids = [str(user.id) for user in response.data]
names = [user.name for user in response.data]
usernames = [user.username for user in response.data]
return data, included, meta, ids, names, usernames, next_token
raise Exception("No retweeters found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, meta, ids, names, usernames, next_token = (
self.get_retweeters(
credentials,
input_data.tweet_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if names:
yield "names", names
if usernames:
yield "usernames", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,757 @@
from datetime import datetime
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
TweetDurationBuilder,
TweetExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetTimeWindowInputs,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterGetUserMentionsBlock(Block):
"""
Returns Tweets where a single user is mentioned, just put that user id
"""
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
user_id: str = SchemaField(
description="Unique identifier of the user for whom to return Tweets mentioning the user",
placeholder="Enter user ID",
)
max_results: int | None = SchemaField(
description="Number of tweets to retrieve (5-100)",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination", default="", advanced=True
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list[str] = SchemaField(description="List of Tweet IDs")
texts: list[str] = SchemaField(description="All Tweet texts")
userIds: list[str] = SchemaField(
description="List of user ids that mentioned the user"
)
userNames: list[str] = SchemaField(
description="List of user names that mentioned the user"
)
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="e01c890c-a630-11ef-9e20-37da24888bd0",
description="This block retrieves Tweets mentioning a specific user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetUserMentionsBlock.Input,
output_schema=TwitterGetUserMentionsBlock.Output,
test_input={
"user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
"max_results": 2,
"start_time": "2024-12-14T18:30:00.000Z",
"end_time": "2024-12-17T18:30:00.000Z",
"since_id": "",
"until_id": "",
"sort_order": None,
"pagination_token": None,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1373001119480344583", "1372627771717869568"]),
("texts", ["Test mention 1", "Test mention 2"]),
("userIds", ["67890", "67891"]),
("userNames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "1373001119480344583", "text": "Test mention 1"},
{"id": "1372627771717869568", "text": "Test mention 2"},
],
),
],
test_mock={
"get_mentions": lambda *args, **kwargs: (
["1373001119480344583", "1372627771717869568"],
["Test mention 1", "Test mention 2"],
["67890", "67891"],
["testuser1", "testuser2"],
[
{"id": "1373001119480344583", "text": "Test mention 1"},
{"id": "1372627771717869568", "text": "Test mention 2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_mentions(
credentials: TwitterCredentials,
user_id: str,
max_results: int | None,
start_time: datetime | None,
end_time: datetime | None,
since_id: str | None,
until_id: str | None,
sort_order: str | None,
pagination: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": user_id,
"max_results": max_results,
"pagination_token": None if pagination == "" else pagination,
"user_auth": False,
}
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
# Adding time window to params If required by the user
params = (
TweetDurationBuilder(params)
.add_start_time(start_time)
.add_end_time(end_time)
.add_since_id(since_id)
.add_until_id(until_id)
.add_sort_order(sort_order)
.build()
)
response = cast(
Response,
client.get_users_mentions(**params),
)
if not response.data and not response.meta:
raise Exception("No tweets found")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
meta = response.meta or {}
next_token = meta.get("next_token", "")
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if "users" in included:
user_ids = [str(user["id"]) for user in included["users"]]
user_names = [user["username"] for user in included["users"]]
return (
tweet_ids,
tweet_texts,
user_ids,
user_names,
data,
included,
meta,
next_token,
)
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta, next_token = (
self.get_mentions(
credentials,
input_data.user_id,
input_data.max_results,
input_data.start_time,
input_data.end_time,
input_data.since_id,
input_data.until_id,
input_data.sort_order,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if user_ids:
yield "userIds", user_ids
if user_names:
yield "userNames", user_names
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetHomeTimelineBlock(Block):
"""
Returns a collection of the most recent Tweets and Retweets posted by you and users you follow
"""
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
max_results: int | None = SchemaField(
description="Number of tweets to retrieve (5-100)",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination", default="", advanced=True
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list[str] = SchemaField(description="List of Tweet IDs")
texts: list[str] = SchemaField(description="All Tweet texts")
userIds: list[str] = SchemaField(
description="List of user ids that authored the tweets"
)
userNames: list[str] = SchemaField(
description="List of user names that authored the tweets"
)
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="d222a070-a630-11ef-a18a-3f52f76c6962",
description="This block retrieves the authenticated user's home timeline.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetHomeTimelineBlock.Input,
output_schema=TwitterGetHomeTimelineBlock.Output,
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"max_results": 2,
"start_time": "2024-12-14T18:30:00.000Z",
"end_time": "2024-12-17T18:30:00.000Z",
"since_id": None,
"until_id": None,
"sort_order": None,
"pagination_token": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1373001119480344583", "1372627771717869568"]),
("texts", ["Test tweet 1", "Test tweet 2"]),
("userIds", ["67890", "67891"]),
("userNames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "1373001119480344583", "text": "Test tweet 1"},
{"id": "1372627771717869568", "text": "Test tweet 2"},
],
),
],
test_mock={
"get_timeline": lambda *args, **kwargs: (
["1373001119480344583", "1372627771717869568"],
["Test tweet 1", "Test tweet 2"],
["67890", "67891"],
["testuser1", "testuser2"],
[
{"id": "1373001119480344583", "text": "Test tweet 1"},
{"id": "1372627771717869568", "text": "Test tweet 2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_timeline(
credentials: TwitterCredentials,
max_results: int | None,
start_time: datetime | None,
end_time: datetime | None,
since_id: str | None,
until_id: str | None,
sort_order: str | None,
pagination: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"max_results": max_results,
"pagination_token": None if pagination == "" else pagination,
"user_auth": False,
}
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
# Adding time window to params If required by the user
params = (
TweetDurationBuilder(params)
.add_start_time(start_time)
.add_end_time(end_time)
.add_since_id(since_id)
.add_until_id(until_id)
.add_sort_order(sort_order)
.build()
)
response = cast(
Response,
client.get_home_timeline(**params),
)
if not response.data and not response.meta:
raise Exception("No tweets found")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
meta = response.meta or {}
next_token = meta.get("next_token", "")
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if "users" in included:
user_ids = [str(user["id"]) for user in included["users"]]
user_names = [user["username"] for user in included["users"]]
return (
tweet_ids,
tweet_texts,
user_ids,
user_names,
data,
included,
meta,
next_token,
)
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta, next_token = (
self.get_timeline(
credentials,
input_data.max_results,
input_data.start_time,
input_data.end_time,
input_data.since_id,
input_data.until_id,
input_data.sort_order,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if user_ids:
yield "userIds", user_ids
if user_names:
yield "userNames", user_names
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetUserTweetsBlock(Block):
"""
Returns Tweets composed by a single user, specified by the requested user ID
"""
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
user_id: str = SchemaField(
description="Unique identifier of the Twitter account (user ID) for whom to return results",
placeholder="Enter user ID",
)
max_results: int | None = SchemaField(
description="Number of tweets to retrieve (5-100)",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for pagination", default="", advanced=True
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list[str] = SchemaField(description="List of Tweet IDs")
texts: list[str] = SchemaField(description="All Tweet texts")
userIds: list[str] = SchemaField(
description="List of user ids that authored the tweets"
)
userNames: list[str] = SchemaField(
description="List of user names that authored the tweets"
)
next_token: str = SchemaField(description="Next token for pagination")
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(
description="Provides metadata such as pagination info (next_token) or result counts"
)
# error
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="c44c3ef2-a630-11ef-9ff7-eb7b5ea3a5cb",
description="This block retrieves Tweets composed by a single user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetUserTweetsBlock.Input,
output_schema=TwitterGetUserTweetsBlock.Output,
test_input={
"user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
"max_results": 2,
"start_time": "2024-12-14T18:30:00.000Z",
"end_time": "2024-12-17T18:30:00.000Z",
"since_id": None,
"until_id": None,
"sort_order": None,
"pagination_token": None,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1373001119480344583", "1372627771717869568"]),
("texts", ["Test tweet 1", "Test tweet 2"]),
("userIds", ["67890", "67891"]),
("userNames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "1373001119480344583", "text": "Test tweet 1"},
{"id": "1372627771717869568", "text": "Test tweet 2"},
],
),
],
test_mock={
"get_user_tweets": lambda *args, **kwargs: (
["1373001119480344583", "1372627771717869568"],
["Test tweet 1", "Test tweet 2"],
["67890", "67891"],
["testuser1", "testuser2"],
[
{"id": "1373001119480344583", "text": "Test tweet 1"},
{"id": "1372627771717869568", "text": "Test tweet 2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_user_tweets(
credentials: TwitterCredentials,
user_id: str,
max_results: int | None,
start_time: datetime | None,
end_time: datetime | None,
since_id: str | None,
until_id: str | None,
sort_order: str | None,
pagination: str | None,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": user_id,
"max_results": max_results,
"pagination_token": None if pagination == "" else pagination,
"user_auth": False,
}
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
# Adding time window to params If required by the user
params = (
TweetDurationBuilder(params)
.add_start_time(start_time)
.add_end_time(end_time)
.add_since_id(since_id)
.add_until_id(until_id)
.add_sort_order(sort_order)
.build()
)
response = cast(
Response,
client.get_users_tweets(**params),
)
if not response.data and not response.meta:
raise Exception("No tweets found")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
meta = response.meta or {}
next_token = meta.get("next_token", "")
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if "users" in included:
user_ids = [str(user["id"]) for user in included["users"]]
user_names = [user["username"] for user in included["users"]]
return (
tweet_ids,
tweet_texts,
user_ids,
user_names,
data,
included,
meta,
next_token,
)
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta, next_token = (
self.get_user_tweets(
credentials,
input_data.user_id,
input_data.max_results,
input_data.start_time,
input_data.end_time,
input_data.since_id,
input_data.until_id,
input_data.sort_order,
input_data.pagination_token,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if user_ids:
yield "userIds", user_ids
if user_names:
yield "userNames", user_names
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,361 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
ExpansionFilter,
TweetExpansionInputs,
TweetFieldsFilter,
TweetMediaFieldsFilter,
TweetPlaceFieldsFilter,
TweetPollFieldsFilter,
TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterGetTweetBlock(Block):
"""
Returns information about a single Tweet specified by the requested ID
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
tweet_id: str = SchemaField(
description="Unique identifier of the Tweet to request (ex: 1460323737035677698)",
placeholder="Enter tweet ID",
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
id: str = SchemaField(description="Tweet ID")
text: str = SchemaField(description="Tweet text")
userId: str = SchemaField(description="ID of the tweet author")
userName: str = SchemaField(description="Username of the tweet author")
# Complete Outputs for advanced use
data: dict = SchemaField(description="Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(description="Metadata about the tweet")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="f5155c3a-a630-11ef-9cc1-a309988b4d92",
description="This block retrieves information about a specific Tweet.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetTweetBlock.Input,
output_schema=TwitterGetTweetBlock.Output,
test_input={
"tweet_id": "1460323737035677698",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", "1460323737035677698"),
("text", "Test tweet content"),
("userId", "12345"),
("userName", "testuser"),
("data", {"id": "1460323737035677698", "text": "Test tweet content"}),
("included", {"users": [{"id": "12345", "username": "testuser"}]}),
("meta", {"result_count": 1}),
],
test_mock={
"get_tweet": lambda *args, **kwargs: (
{"id": "1460323737035677698", "text": "Test tweet content"},
{"users": [{"id": "12345", "username": "testuser"}]},
{"result_count": 1},
"12345",
"testuser",
)
},
)
@staticmethod
def get_tweet(
credentials: TwitterCredentials,
tweet_id: str,
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {"id": tweet_id, "user_auth": False}
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_tweet(**params))
meta = {}
user_id = ""
user_name = ""
if response.meta:
meta = response.meta
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_dict(response.data)
if included and "users" in included:
user_id = str(included["users"][0]["id"])
user_name = included["users"][0]["username"]
if response.data:
return data, included, meta, user_id, user_name
raise Exception("Tweet not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
tweet_data, included, meta, user_id, user_name = self.get_tweet(
credentials,
input_data.tweet_id,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
yield "id", str(tweet_data["id"])
yield "text", tweet_data["text"]
if user_id:
yield "userId", user_id
if user_name:
yield "userName", user_name
yield "data", tweet_data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetTweetsBlock(Block):
"""
Returns information about multiple Tweets specified by the requested IDs
"""
class Input(TweetExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["tweet.read", "users.read", "offline.access"]
)
tweet_ids: list[str] = SchemaField(
description="List of Tweet IDs to request (up to 100)",
placeholder="Enter tweet IDs",
)
class Output(BlockSchema):
# Common Outputs that user commonly uses
ids: list[str] = SchemaField(description="All Tweet IDs")
texts: list[str] = SchemaField(description="All Tweet texts")
userIds: list[str] = SchemaField(
description="List of user ids that authored the tweets"
)
userNames: list[str] = SchemaField(
description="List of user names that authored the tweets"
)
# Complete Outputs for advanced use
data: list[dict] = SchemaField(description="Complete Tweet data")
included: dict = SchemaField(
description="Additional data that you have requested (Optional) via Expansions field"
)
meta: dict = SchemaField(description="Metadata about the tweets")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="e7cc5420-a630-11ef-bfaf-13bdd8096a51",
description="This block retrieves information about multiple Tweets.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetTweetsBlock.Input,
output_schema=TwitterGetTweetsBlock.Output,
test_input={
"tweet_ids": ["1460323737035677698"],
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"media_fields": None,
"place_fields": None,
"poll_fields": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1460323737035677698"]),
("texts", ["Test tweet content"]),
("userIds", ["67890"]),
("userNames", ["testuser1"]),
("data", [{"id": "1460323737035677698", "text": "Test tweet content"}]),
("included", {"users": [{"id": "67890", "username": "testuser1"}]}),
("meta", {"result_count": 1}),
],
test_mock={
"get_tweets": lambda *args, **kwargs: (
["1460323737035677698"], # ids
["Test tweet content"], # texts
["67890"], # user_ids
["testuser1"], # user_names
[
{"id": "1460323737035677698", "text": "Test tweet content"}
], # data
{"users": [{"id": "67890", "username": "testuser1"}]}, # included
{"result_count": 1}, # meta
)
},
)
@staticmethod
def get_tweets(
credentials: TwitterCredentials,
tweet_ids: list[str],
expansions: ExpansionFilter | None,
media_fields: TweetMediaFieldsFilter | None,
place_fields: TweetPlaceFieldsFilter | None,
poll_fields: TweetPollFieldsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {"ids": tweet_ids, "user_auth": False}
# Adding expansions to params If required by the user
params = (
TweetExpansionsBuilder(params)
.add_expansions(expansions)
.add_media_fields(media_fields)
.add_place_fields(place_fields)
.add_poll_fields(poll_fields)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_tweets(**params))
if not response.data and not response.meta:
raise Exception("No tweets found")
tweet_ids = []
tweet_texts = []
user_ids = []
user_names = []
meta = {}
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
tweet_ids = [str(tweet.id) for tweet in response.data]
tweet_texts = [tweet.text for tweet in response.data]
if included and "users" in included:
for user in included["users"]:
user_ids.append(str(user["id"]))
user_names.append(user["username"])
if response.meta:
meta = response.meta
return tweet_ids, tweet_texts, user_ids, user_names, data, included, meta
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, texts, user_ids, user_names, data, included, meta = self.get_tweets(
credentials,
input_data.tweet_ids,
input_data.expansions,
input_data.media_fields,
input_data.place_fields,
input_data.poll_fields,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if texts:
yield "texts", texts
if user_ids:
yield "userIds", user_ids
if user_names:
yield "userNames", user_names
if data:
yield "data", data
if included:
yield "included", included
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,305 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import IncludesSerializer
from backend.blocks.twitter._types import (
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterUnblockUserBlock(Block):
"""
Unblock a specific user on Twitter. The request succeeds with no action when the user sends a request to a user they're not blocking or have already unblocked.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["block.write", "users.read", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to unblock",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the unblock was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="0f1b6570-a631-11ef-a3ea-230cbe9650dd",
description="This block unblocks a specific user on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnblockUserBlock.Input,
output_schema=TwitterUnblockUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unblock_user": lambda *args, **kwargs: True},
)
@staticmethod
def unblock_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unblock(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unblock_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetBlockedUsersBlock(Block):
"""
Get a list of users who are blocked by the authenticating user
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access", "block.read"]
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-1000, default 100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for retrieving next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
user_ids: list[str] = SchemaField(description="List of blocked user IDs")
usernames_: list[str] = SchemaField(description="List of blocked usernames")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
next_token: str = SchemaField(description="Next token for pagination")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="05f409e8-a631-11ef-ae89-93de863ee30d",
description="This block retrieves a list of users blocked by the authenticating user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetBlockedUsersBlock.Input,
output_schema=TwitterGetBlockedUsersBlock.Output,
test_input={
"max_results": 10,
"pagination_token": "",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("user_ids", ["12345", "67890"]),
("usernames_", ["testuser1", "testuser2"]),
],
test_mock={
"get_blocked_users": lambda *args, **kwargs: (
{}, # included
{}, # meta
["12345", "67890"], # user_ids
["testuser1", "testuser2"], # usernames
None, # next_token
)
},
)
@staticmethod
def get_blocked_users(
credentials: TwitterCredentials,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_blocked(**params))
meta = {}
user_ids = []
usernames = []
next_token = None
included = IncludesSerializer.serialize(response.includes)
if response.data:
for user in response.data:
user_ids.append(str(user.id))
usernames.append(user.username)
if response.meta:
meta = response.meta
if "next_token" in meta:
next_token = meta["next_token"]
if user_ids and usernames:
return included, meta, user_ids, usernames, next_token
else:
raise tweepy.TweepyException("No blocked users found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
included, meta, user_ids, usernames, next_token = self.get_blocked_users(
credentials,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if user_ids:
yield "user_ids", user_ids
if usernames:
yield "usernames_", usernames
if included:
yield "included", included
if meta:
yield "meta", meta
if next_token:
yield "next_token", next_token
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterBlockUserBlock(Block):
"""
Block a specific user on Twitter
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["block.write", "users.read", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to block",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(description="Whether the block was successful")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="fc258b94-a630-11ef-abc3-df050b75b816",
description="This block blocks a specific user on Twitter.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterBlockUserBlock.Input,
output_schema=TwitterBlockUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"block_user": lambda *args, **kwargs: True},
)
@staticmethod
def block_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.block(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.block_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,510 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterUnfollowUserBlock(Block):
"""
Allows a user to unfollow another user specified by target user ID.
The request succeeds with no action when the authenticated user sends a request to a user they're not following or have already unfollowed.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "users.write", "follows.write", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to unfollow",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the unfollow action was successful"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="37e386a4-a631-11ef-b7bd-b78204b35fa4",
description="This block unfollows a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnfollowUserBlock.Input,
output_schema=TwitterUnfollowUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unfollow_user": lambda *args, **kwargs: True},
)
@staticmethod
def unfollow_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unfollow_user(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unfollow_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterFollowUserBlock(Block):
"""
Allows a user to follow another user specified by target user ID. If the target user does not have public Tweets,
this endpoint will send a follow request. The request succeeds with no action when the authenticated user sends a
request to a user they're already following, or if they're sending a follower request to a user that does not have
public Tweets.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "users.write", "follows.write", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to follow",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the follow action was successful"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="1aae6a5e-a631-11ef-a090-435900c6d429",
description="This block follows a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterFollowUserBlock.Input,
output_schema=TwitterFollowUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[("success", True)],
test_mock={"follow_user": lambda *args, **kwargs: True},
)
@staticmethod
def follow_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.follow_user(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.follow_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetFollowersBlock(Block):
"""
Retrieves a list of followers for a specified Twitter user ID
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access", "follows.read"]
)
target_user_id: str = SchemaField(
description="The user ID whose followers you would like to retrieve",
placeholder="Enter target user ID",
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-1000, default 100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for retrieving next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
ids: list[str] = SchemaField(description="List of follower user IDs")
usernames: list[str] = SchemaField(description="List of follower usernames")
next_token: str = SchemaField(description="Next token for pagination")
data: list[dict] = SchemaField(description="Complete user data for followers")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="30f66410-a631-11ef-8fe7-d7f888b4f43c",
description="This block retrieves followers of a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetFollowersBlock.Input,
output_schema=TwitterGetFollowersBlock.Output,
test_input={
"target_user_id": "12345",
"max_results": 1,
"pagination_token": "",
"expansions": None,
"tweet_fields": None,
"user_fields": None,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1234567890"]),
("usernames", ["testuser"]),
("data", [{"id": "1234567890", "username": "testuser"}]),
],
test_mock={
"get_followers": lambda *args, **kwargs: (
["1234567890"],
["testuser"],
[{"id": "1234567890", "username": "testuser"}],
{},
{},
None,
)
},
)
@staticmethod
def get_followers(
credentials: TwitterCredentials,
target_user_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": target_user_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_users_followers(**params))
meta = {}
follower_ids = []
follower_usernames = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
follower_ids = [str(user.id) for user in response.data]
follower_usernames = [user.username for user in response.data]
return (
follower_ids,
follower_usernames,
data,
included,
meta,
next_token,
)
raise Exception("Followers not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, usernames, data, includes, meta, next_token = self.get_followers(
credentials,
input_data.target_user_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if usernames:
yield "usernames", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if includes:
yield "includes", includes
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetFollowingBlock(Block):
"""
Retrieves a list of users that a specified Twitter user ID is following
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access", "follows.read"]
)
target_user_id: str = SchemaField(
description="The user ID whose following you would like to retrieve",
placeholder="Enter target user ID",
)
max_results: int | None = SchemaField(
description="Maximum number of results to return (1-1000, default 100)",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token for retrieving next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
ids: list[str] = SchemaField(description="List of following user IDs")
usernames: list[str] = SchemaField(description="List of following usernames")
next_token: str = SchemaField(description="Next token for pagination")
data: list[dict] = SchemaField(description="Complete user data for following")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="264a399c-a631-11ef-a97d-bfde4ca91173",
description="This block retrieves the users that a specified Twitter user is following.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetFollowingBlock.Input,
output_schema=TwitterGetFollowingBlock.Output,
test_input={
"target_user_id": "12345",
"max_results": 1,
"pagination_token": None,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["1234567890"]),
("usernames", ["testuser"]),
("data", [{"id": "1234567890", "username": "testuser"}]),
],
test_mock={
"get_following": lambda *args, **kwargs: (
["1234567890"],
["testuser"],
[{"id": "1234567890", "username": "testuser"}],
{},
{},
None,
)
},
)
@staticmethod
def get_following(
credentials: TwitterCredentials,
target_user_id: str,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": target_user_id,
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_users_following(**params))
meta = {}
following_ids = []
following_usernames = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
following_ids = [str(user.id) for user in response.data]
following_usernames = [user.username for user in response.data]
return (
following_ids,
following_usernames,
data,
included,
meta,
next_token,
)
raise Exception("Following not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, usernames, data, includes, meta, next_token = self.get_following(
credentials,
input_data.target_user_id,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if usernames:
yield "usernames", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if includes:
yield "includes", includes
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,328 @@
from typing import cast
import tweepy
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class TwitterUnmuteUserBlock(Block):
"""
Allows a user to unmute another user specified by target user ID.
The request succeeds with no action when the user sends a request to a user they're not muting or have already unmuted.
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "users.write", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to unmute",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the unmute action was successful"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="40458504-a631-11ef-940b-eff92be55422",
description="This block unmutes a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterUnmuteUserBlock.Input,
output_schema=TwitterUnmuteUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"unmute_user": lambda *args, **kwargs: True},
)
@staticmethod
def unmute_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.unmute(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.unmute_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterGetMutedUsersBlock(Block):
"""
Returns a list of users who are muted by the authenticating user
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access"]
)
max_results: int | None = SchemaField(
description="The maximum number of results to be returned per page (1-1000). Default is 100.",
placeholder="Enter max results",
default=10,
advanced=True,
)
pagination_token: str | None = SchemaField(
description="Token to request next/previous page of results",
placeholder="Enter pagination token",
default="",
advanced=True,
)
class Output(BlockSchema):
ids: list[str] = SchemaField(description="List of muted user IDs")
usernames: list[str] = SchemaField(description="List of muted usernames")
next_token: str = SchemaField(description="Next token for pagination")
data: list[dict] = SchemaField(description="Complete user data for muted users")
includes: dict = SchemaField(
description="Additional data requested via expansions"
)
meta: dict = SchemaField(description="Metadata including pagination info")
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="475024da-a631-11ef-9ccd-f724b8b03cda",
description="This block gets a list of users muted by the authenticating user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetMutedUsersBlock.Input,
output_schema=TwitterGetMutedUsersBlock.Output,
test_input={
"max_results": 2,
"pagination_token": "",
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["12345", "67890"]),
("usernames", ["testuser1", "testuser2"]),
(
"data",
[
{"id": "12345", "username": "testuser1"},
{"id": "67890", "username": "testuser2"},
],
),
],
test_mock={
"get_muted_users": lambda *args, **kwargs: (
["12345", "67890"],
["testuser1", "testuser2"],
[
{"id": "12345", "username": "testuser1"},
{"id": "67890", "username": "testuser2"},
],
{},
{},
None,
)
},
)
@staticmethod
def get_muted_users(
credentials: TwitterCredentials,
max_results: int | None,
pagination_token: str | None,
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"max_results": max_results,
"pagination_token": (
None if pagination_token == "" else pagination_token
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_muted(**params))
meta = {}
user_ids = []
usernames = []
next_token = None
if response.meta:
meta = response.meta
next_token = meta.get("next_token")
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
user_ids = [str(item.id) for item in response.data]
usernames = [item.username for item in response.data]
return user_ids, usernames, data, included, meta, next_token
raise Exception("Muted users not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
ids, usernames, data, includes, meta, next_token = self.get_muted_users(
credentials,
input_data.max_results,
input_data.pagination_token,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if usernames:
yield "usernames", usernames
if next_token:
yield "next_token", next_token
if data:
yield "data", data
if includes:
yield "includes", includes
if meta:
yield "meta", meta
except Exception as e:
yield "error", handle_tweepy_exception(e)
class TwitterMuteUserBlock(Block):
"""
Allows a user to mute another user specified by target user ID
"""
class Input(BlockSchema):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "users.write", "offline.access"]
)
target_user_id: str = SchemaField(
description="The user ID of the user that you would like to mute",
placeholder="Enter target user ID",
)
class Output(BlockSchema):
success: bool = SchemaField(
description="Whether the mute action was successful"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="4d1919d0-a631-11ef-90ab-3b73af9ce8f1",
description="This block mutes a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterMuteUserBlock.Input,
output_schema=TwitterMuteUserBlock.Output,
test_input={
"target_user_id": "12345",
"credentials": TEST_CREDENTIALS_INPUT,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("success", True),
],
test_mock={"mute_user": lambda *args, **kwargs: True},
)
@staticmethod
def mute_user(credentials: TwitterCredentials, target_user_id: str):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
client.mute(target_user_id=target_user_id, user_auth=False)
return True
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
success = self.mute_user(credentials, input_data.target_user_id)
yield "success", success
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -0,0 +1,383 @@
from typing import Literal, Union, cast
import tweepy
from pydantic import BaseModel
from tweepy.client import Response
from backend.blocks.twitter._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TwitterCredentials,
TwitterCredentialsField,
TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
IncludesSerializer,
ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
TweetFieldsFilter,
TweetUserFieldsFilter,
UserExpansionInputs,
UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
class UserId(BaseModel):
discriminator: Literal["user_id"]
user_id: str = SchemaField(description="The ID of the user to lookup", default="")
class Username(BaseModel):
discriminator: Literal["username"]
username: str = SchemaField(
description="The Twitter username (handle) of the user", default=""
)
class TwitterGetUserBlock(Block):
"""
Gets information about a single Twitter user specified by ID or username
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access"]
)
identifier: Union[UserId, Username] = SchemaField(
discriminator="discriminator",
description="Choose whether to identify the user by their unique Twitter ID or by their username",
advanced=False,
)
class Output(BlockSchema):
# Common outputs
id: str = SchemaField(description="User ID")
username_: str = SchemaField(description="User username")
name_: str = SchemaField(description="User name")
# Complete outputs
data: dict = SchemaField(description="Complete user data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="5446db8e-a631-11ef-812a-cf315d373ee9",
description="This block retrieves information about a specified Twitter user.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetUserBlock.Input,
output_schema=TwitterGetUserBlock.Output,
test_input={
"identifier": {"discriminator": "username", "username": "twitter"},
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("id", "783214"),
("username_", "twitter"),
("name_", "Twitter"),
(
"data",
{
"user": {
"id": "783214",
"username": "twitter",
"name": "Twitter",
}
},
),
],
test_mock={
"get_user": lambda *args, **kwargs: (
{
"user": {
"id": "783214",
"username": "twitter",
"name": "Twitter",
}
},
{},
"twitter",
"783214",
"Twitter",
)
},
)
@staticmethod
def get_user(
credentials: TwitterCredentials,
identifier: Union[UserId, Username],
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"id": identifier.user_id if isinstance(identifier, UserId) else None,
"username": (
identifier.username if isinstance(identifier, Username) else None
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_user(**params))
username = ""
id = ""
name = ""
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_dict(response.data)
if response.data:
username = response.data.username
id = str(response.data.id)
name = response.data.name
if username and id:
return data, included, username, id, name
else:
raise tweepy.TweepyException("User not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, username, id, name = self.get_user(
credentials,
input_data.identifier,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if id:
yield "id", id
if username:
yield "username_", username
if name:
yield "name_", name
if data:
yield "data", data
if included:
yield "included", included
except Exception as e:
yield "error", handle_tweepy_exception(e)
class UserIdList(BaseModel):
discriminator: Literal["user_id_list"]
user_ids: list[str] = SchemaField(
description="List of user IDs to lookup (max 100)",
placeholder="Enter user IDs",
default=[],
advanced=False,
)
class UsernameList(BaseModel):
discriminator: Literal["username_list"]
usernames: list[str] = SchemaField(
description="List of Twitter usernames/handles to lookup (max 100)",
placeholder="Enter usernames",
default=[],
advanced=False,
)
class TwitterGetUsersBlock(Block):
"""
Gets information about multiple Twitter users specified by IDs or usernames
"""
class Input(UserExpansionInputs):
credentials: TwitterCredentialsInput = TwitterCredentialsField(
["users.read", "offline.access"]
)
identifier: Union[UserIdList, UsernameList] = SchemaField(
discriminator="discriminator",
description="Choose whether to identify users by their unique Twitter IDs or by their usernames",
advanced=False,
)
class Output(BlockSchema):
# Common outputs
ids: list[str] = SchemaField(description="User IDs")
usernames_: list[str] = SchemaField(description="User usernames")
names_: list[str] = SchemaField(description="User names")
# Complete outputs
data: list[dict] = SchemaField(description="Complete users data")
included: dict = SchemaField(
description="Additional data requested via expansions"
)
error: str = SchemaField(description="Error message if the request failed")
def __init__(self):
super().__init__(
id="5abc857c-a631-11ef-8cfc-f7b79354f7a1",
description="This block retrieves information about multiple Twitter users.",
categories={BlockCategory.SOCIAL},
input_schema=TwitterGetUsersBlock.Input,
output_schema=TwitterGetUsersBlock.Output,
test_input={
"identifier": {
"discriminator": "username_list",
"usernames": ["twitter", "twitterdev"],
},
"credentials": TEST_CREDENTIALS_INPUT,
"expansions": None,
"tweet_fields": None,
"user_fields": None,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("ids", ["783214", "2244994945"]),
("usernames_", ["twitter", "twitterdev"]),
("names_", ["Twitter", "Twitter Dev"]),
(
"data",
[
{"id": "783214", "username": "twitter", "name": "Twitter"},
{
"id": "2244994945",
"username": "twitterdev",
"name": "Twitter Dev",
},
],
),
],
test_mock={
"get_users": lambda *args, **kwargs: (
[
{"id": "783214", "username": "twitter", "name": "Twitter"},
{
"id": "2244994945",
"username": "twitterdev",
"name": "Twitter Dev",
},
],
{},
["twitter", "twitterdev"],
["783214", "2244994945"],
["Twitter", "Twitter Dev"],
)
},
)
@staticmethod
def get_users(
credentials: TwitterCredentials,
identifier: Union[UserIdList, UsernameList],
expansions: UserExpansionsFilter | None,
tweet_fields: TweetFieldsFilter | None,
user_fields: TweetUserFieldsFilter | None,
):
try:
client = tweepy.Client(
bearer_token=credentials.access_token.get_secret_value()
)
params = {
"ids": (
",".join(identifier.user_ids)
if isinstance(identifier, UserIdList)
else None
),
"usernames": (
",".join(identifier.usernames)
if isinstance(identifier, UsernameList)
else None
),
"user_auth": False,
}
params = (
UserExpansionsBuilder(params)
.add_expansions(expansions)
.add_tweet_fields(tweet_fields)
.add_user_fields(user_fields)
.build()
)
response = cast(Response, client.get_users(**params))
usernames = []
ids = []
names = []
included = IncludesSerializer.serialize(response.includes)
data = ResponseDataSerializer.serialize_list(response.data)
if response.data:
for user in response.data:
usernames.append(user.username)
ids.append(str(user.id))
names.append(user.name)
if usernames and ids:
return data, included, usernames, ids, names
else:
raise tweepy.TweepyException("Users not found")
except tweepy.TweepyException:
raise
def run(
self,
input_data: Input,
*,
credentials: TwitterCredentials,
**kwargs,
) -> BlockOutput:
try:
data, included, usernames, ids, names = self.get_users(
credentials,
input_data.identifier,
input_data.expansions,
input_data.tweet_fields,
input_data.user_fields,
)
if ids:
yield "ids", ids
if usernames:
yield "usernames_", usernames
if names:
yield "names_", names
if data:
yield "data", data
if included:
yield "included", included
except Exception as e:
yield "error", handle_tweepy_exception(e)

View File

@ -61,6 +61,9 @@ class BlockCategory(Enum):
HARDWARE = "Block that interacts with hardware."
AGENT = "Block that interacts with other agents."
CRM = "Block that interacts with CRM services."
SAFETY = (
"Block that provides AI safety mechanisms such as detecting harmful content"
)
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}

View File

@ -51,6 +51,7 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.LLAMA3_1_405B: 1,
LlmModel.LLAMA3_1_70B: 1,
LlmModel.LLAMA3_1_8B: 1,
LlmModel.OLLAMA_LLAMA3_2: 1,
LlmModel.OLLAMA_LLAMA3_8B: 1,
LlmModel.OLLAMA_LLAMA3_405B: 1,
LlmModel.OLLAMA_DOLPHIN: 1,

View File

@ -136,6 +136,7 @@ async def create_graph_execution(
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
@ -163,6 +164,7 @@ async def create_graph_execution(
]
},
"userId": user_id,
"agentPresetId": preset_id,
},
include=GRAPH_EXECUTION_INCLUDE,
)

View File

@ -6,7 +6,13 @@ from datetime import datetime, timezone
from typing import Any, Literal, Optional, Type
import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.models import (
AgentGraph,
AgentGraphExecution,
AgentNode,
AgentNodeLink,
StoreListingVersion,
)
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
async def get_graph(
graph_id: str,
version: int | None = None,
template: bool = False,
user_id: str | None = None,
for_export: bool = False,
) -> GraphModel | None:
@ -543,21 +548,36 @@ async def get_graph(
where_clause: AgentGraphWhereInput = {
"id": graph_id,
}
if version is not None:
where_clause["version"] = version
elif not template:
else:
where_clause["isActive"] = True
# TODO: Fix hack workaround to get adding store agents to work
if user_id is not None and not template:
where_clause["userId"] = user_id
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return GraphModel.from_db(graph, for_export) if graph else None
# The Graph has to be owned by the user or a store listing.
if (
graph is None
or graph.userId != user_id
and not (
await StoreListingVersion.prisma().find_first(
where=prisma.types.StoreListingVersionWhereInput(
agentId=graph_id,
agentVersion=version or graph.version,
isDeleted=False,
StoreListing={"is": {"isApproved": True}},
)
)
)
):
return None
return GraphModel.from_db(graph, for_export)
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
@ -611,9 +631,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
async with transaction() as tx:
await __create_graph(tx, graph, user_id)
if created_graph := await get_graph(
graph.id, graph.version, graph.is_template, user_id=user_id
):
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")

View File

@ -139,6 +139,8 @@ def SchemaField(
exclude: bool = False,
hidden: Optional[bool] = None,
depends_on: list[str] | None = None,
image_upload: Optional[bool] = None,
image_output: Optional[bool] = None,
**kwargs,
) -> T:
if default is PydanticUndefined and default_factory is None:
@ -154,6 +156,8 @@ def SchemaField(
"advanced": advanced,
"hidden": hidden,
"depends_on": depends_on,
"image_upload": image_upload,
"image_output": image_output,
}.items()
if v is not None
}
@ -222,6 +226,7 @@ class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
code_verifier: Optional[str] = None
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
scopes: list[str]

View File

@ -780,7 +780,8 @@ class ExecutionManager(AppService):
graph_id: str,
data: BlockInput,
user_id: str,
graph_version: int | None = None,
graph_version: int,
preset_id: str | None = None,
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
@ -829,6 +830,7 @@ class ExecutionManager(AppService):
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
starting_node_execs = []

View File

@ -63,7 +63,10 @@ def execute_graph(**kwargs):
try:
log(f"Executing recurring job for graph #{args.graph_id}")
get_execution_client().add_execution(
args.graph_id, args.input_data, args.user_id
graph_id=args.graph_id,
data=args.input_data,
user_id=args.user_id,
graph_version=args.graph_version,
)
except Exception as e:
logger.exception(f"Error executing graph {args.graph_id}: {e}")

View File

@ -1,6 +1,8 @@
import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from pydantic import SecretStr
@ -91,6 +93,34 @@ open_router_credentials = APIKeyCredentials(
title="Use Credits for Open Router",
expires_at=None,
)
fal_credentials = APIKeyCredentials(
id="6c0f5bd0-9008-4638-9d79-4b40b631803e",
provider="fal",
api_key=SecretStr(settings.secrets.fal_api_key),
title="Use Credits for FAL",
expires_at=None,
)
exa_credentials = APIKeyCredentials(
id="96153e04-9c6c-4486-895f-5bb683b1ecec",
provider="exa",
api_key=SecretStr(settings.secrets.exa_api_key),
title="Use Credits for Exa search",
expires_at=None,
)
e2b_credentials = APIKeyCredentials(
id="78d19fd7-4d59-4a16-8277-3ce310acf2b7",
provider="e2b",
api_key=SecretStr(settings.secrets.e2b_api_key),
title="Use Credits for E2B",
expires_at=None,
)
nvidia_credentials = APIKeyCredentials(
id="96b83908-2789-4dec-9968-18f0ece4ceb3",
provider="nvidia",
api_key=SecretStr(settings.secrets.nvidia_api_key),
title="Use Credits for Nvidia",
expires_at=None,
)
DEFAULT_CREDENTIALS = [
@ -104,6 +134,10 @@ DEFAULT_CREDENTIALS = [
jina_credentials,
unreal_credentials,
open_router_credentials,
fal_credentials,
exa_credentials,
e2b_credentials,
nvidia_credentials,
]
@ -155,6 +189,14 @@ class IntegrationCredentialsStore:
all_credentials.append(unreal_credentials)
if settings.secrets.open_router_api_key:
all_credentials.append(open_router_credentials)
if settings.secrets.fal_api_key:
all_credentials.append(fal_credentials)
if settings.secrets.exa_api_key:
all_credentials.append(exa_credentials)
if settings.secrets.e2b_api_key:
all_credentials.append(e2b_credentials)
if settings.secrets.nvidia_api_key:
all_credentials.append(nvidia_credentials)
return all_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
@ -210,18 +252,24 @@ class IntegrationCredentialsStore:
]
self._set_user_integration_creds(user_id, filtered_credentials)
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
def store_state_token(
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
) -> tuple[str, str]:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
(code_challenge, code_verifier) = self._generate_code_challenge()
state = OAuthState(
token=token,
provider=provider,
code_verifier=code_verifier,
expires_at=int(expires_at.timestamp()),
scopes=scopes,
)
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
@ -231,39 +279,21 @@ class IntegrationCredentialsStore:
user_id=user_id, data=user_integrations
)
return token
return token, code_challenge
def get_any_valid_scopes_from_state_token(
def _generate_code_challenge(self) -> tuple[str, str]:
"""
Generate code challenge using SHA256 from the code verifier.
Currently only SHA256 is supported.(In future if we want to support more methods we can add them here)
"""
code_verifier = secrets.token_urlsafe(128)
sha256_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(sha256_hash).decode("utf-8")
return code_challenge.replace("=", ""), code_verifier
def verify_state_token(
self, user_id: str, token: str, provider: str
) -> list[str]:
"""
Get the valid scopes from the OAuth state token. This will return any valid scopes
from any OAuth state token for the given provider. If no valid scopes are found,
an empty list is returned. DO NOT RELY ON THIS TOKEN TO AUTHENTICATE A USER, AS IT
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
THE CODE FOR TOKENS.
"""
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state.token == token
and state.provider == provider
and state.expires_at > now.timestamp()
),
None,
)
if valid_state:
return valid_state.scopes
return []
def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
) -> Optional[OAuthState]:
with self.locked_user_integrations(user_id):
user_integrations = self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
@ -285,9 +315,9 @@ class IntegrationCredentialsStore:
oauth_states.remove(valid_state)
user_integrations.oauth_states = oauth_states
self.db_manager.update_user_integrations(user_id, user_integrations)
return True
return valid_state
return False
return None
def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler
if TYPE_CHECKING:
from ..providers import ProviderName
@ -15,6 +16,7 @@ HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
GitHubOAuthHandler,
GoogleOAuthHandler,
NotionOAuthHandler,
TwitterOAuthHandler,
]
}
# --8<-- [end:HANDLERS_BY_NAMEExample]

View File

@ -1,7 +1,7 @@
import logging
import time
from abc import ABC, abstractmethod
from typing import ClassVar
from typing import ClassVar, Optional
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
@ -23,7 +23,9 @@ class BaseOAuthHandler(ABC):
@abstractmethod
# --8<-- [start:BaseOAuthHandler3]
def get_login_url(self, scopes: list[str], state: str) -> str:
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
# --8<-- [end:BaseOAuthHandler3]
"""Constructs a login URL that the user can be redirected to"""
...
@ -31,7 +33,7 @@ class BaseOAuthHandler(ABC):
@abstractmethod
# --8<-- [start:BaseOAuthHandler4]
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
# --8<-- [end:BaseOAuthHandler4]
"""Exchanges the acquired authorization code from login for a set of tokens"""

View File

@ -34,7 +34,9 @@ class GitHubOAuthHandler(BaseOAuthHandler):
self.token_url = "https://github.com/login/oauth/access_token"
self.revoke_url = "https://api.github.com/applications/{client_id}/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@ -44,7 +46,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})

View File

@ -1,4 +1,5 @@
import logging
from typing import Optional
from google.auth.external_account_authorized_user import (
Credentials as ExternalAccountCredentials,
@ -38,7 +39,9 @@ class GoogleOAuthHandler(BaseOAuthHandler):
self.token_uri = "https://oauth2.googleapis.com/token"
self.revoke_uri = "https://oauth2.googleapis.com/revoke"
def get_login_url(self, scopes: list[str], state: str) -> str:
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
logger.debug(f"Setting up OAuth flow with scopes: {all_scopes}")
flow = self._setup_oauth_flow(all_scopes)
@ -52,7 +55,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
return authorization_url
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
logger.debug(f"Exchanging code for tokens with scopes: {scopes}")

View File

@ -1,4 +1,5 @@
from base64 import b64encode
from typing import Optional
from urllib.parse import urlencode
from backend.data.model import OAuth2Credentials
@ -26,7 +27,9 @@ class NotionOAuthHandler(BaseOAuthHandler):
self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
self.token_url = "https://api.notion.com/v1/oauth/token"
def get_login_url(self, scopes: list[str], state: str) -> str:
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@ -37,7 +40,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
return f"{self.auth_base_url}?{urlencode(params)}"
def exchange_code_for_tokens(
self, code: str, scopes: list[str]
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
request_body = {
"grant_type": "authorization_code",

View File

@ -0,0 +1,171 @@
import time
import urllib.parse
from typing import ClassVar, Optional
import requests
from backend.data.model import OAuth2Credentials, ProviderName
from backend.integrations.oauth.base import BaseOAuthHandler
class TwitterOAuthHandler(BaseOAuthHandler):
PROVIDER_NAME = ProviderName.TWITTER
DEFAULT_SCOPES: ClassVar[list[str]] = [
"tweet.read",
"tweet.write",
"tweet.moderate.write",
"users.read",
"follows.read",
"follows.write",
"offline.access",
"space.read",
"mute.read",
"mute.write",
"like.read",
"like.write",
"list.read",
"list.write",
"block.read",
"block.write",
"bookmark.read",
"bookmark.write",
]
AUTHORIZE_URL = "https://twitter.com/i/oauth2/authorize"
TOKEN_URL = "https://api.x.com/2/oauth2/token"
USERNAME_URL = "https://api.x.com/2/users/me"
REVOKE_URL = "https://api.x.com/2/oauth2/revoke"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_login_url(
self, scopes: list[str], state: str, code_challenge: Optional[str]
) -> str:
"""Generate Twitter OAuth 2.0 authorization URL"""
# scopes = self.handle_default_scopes(scopes)
if code_challenge is None:
raise ValueError("code_challenge is required for Twitter OAuth")
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.DEFAULT_SCOPES),
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
def exchange_code_for_tokens(
self, code: str, scopes: list[str], code_verifier: Optional[str]
) -> OAuth2Credentials:
"""Exchange authorization code for access tokens"""
headers = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.redirect_uri,
"code_verifier": code_verifier,
}
auth = (self.client_id, self.client_secret)
response = requests.post(self.TOKEN_URL, headers=headers, data=data, auth=auth)
response.raise_for_status()
tokens = response.json()
username = self._get_username(tokens["access_token"])
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens.get("refresh_token"),
access_token_expires_at=int(time.time()) + tokens["expires_in"],
refresh_token_expires_at=None,
scopes=scopes,
)
def _get_username(self, access_token: str) -> str:
"""Get the username from the access token"""
headers = {"Authorization": f"Bearer {access_token}"}
params = {"user.fields": "username"}
response = requests.get(
f"{self.USERNAME_URL}?{urllib.parse.urlencode(params)}", headers=headers
)
response.raise_for_status()
return response.json()["data"]["username"]
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
"""Refresh access tokens using refresh token"""
if not credentials.refresh_token:
raise ValueError("No refresh token available")
header = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
}
auth = (self.client_id, self.client_secret)
response = requests.post(self.TOKEN_URL, headers=header, data=data, auth=auth)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
print("HTTP Error:", e)
print("Response Content:", response.text)
raise
tokens = response.json()
username = self._get_username(tokens["access_token"])
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=None,
username=username,
access_token=tokens["access_token"],
refresh_token=tokens["refresh_token"],
access_token_expires_at=int(time.time()) + tokens["expires_in"],
scopes=credentials.scopes,
refresh_token_expires_at=None,
)
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
"""Revoke the access token"""
header = {"Content-Type": "application/x-www-form-urlencoded"}
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
}
auth = (self.client_id, self.client_secret)
response = requests.post(self.REVOKE_URL, headers=header, data=data, auth=auth)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
print("HTTP Error:", e)
print("Response Content:", response.text)
raise
return response.status_code == 200

View File

@ -19,6 +19,7 @@ class ProviderName(str, Enum):
JINA = "jina"
MEDIUM = "medium"
NOTION = "notion"
NVIDIA = "nvidia"
OLLAMA = "ollama"
OPENAI = "openai"
OPENWEATHERMAP = "openweathermap"
@ -27,5 +28,6 @@ class ProviderName(str, Enum):
REPLICATE = "replicate"
REVID = "revid"
SLANT3D = "slant3d"
TWITTER = "twitter"
UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]

View File

@ -7,6 +7,6 @@ app_config = Config()
# TODO: add test to assert this matches the actual API route
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
return (
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
f"/webhooks/{webhook_id}/ingress"
)

View File

@ -60,11 +60,12 @@ def login(
requested_scopes = scopes.split(",") if scopes else []
# Generate and store a secure random state token along with the scopes
state_token = creds_manager.store.store_state_token(
state_token, code_challenge = creds_manager.store.store_state_token(
user_id, provider, requested_scopes
)
login_url = handler.get_login_url(requested_scopes, state_token)
login_url = handler.get_login_url(
requested_scopes, state_token, code_challenge=code_challenge
)
return LoginResponse(login_url=login_url, state_token=state_token)
@ -92,19 +93,21 @@ def callback(
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not creds_manager.store.verify_state_token(user_id, state_token, provider):
valid_state = creds_manager.store.verify_state_token(user_id, state_token, provider)
if not valid_state:
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
user_id, state_token, provider
)
scopes = valid_state.scopes
logger.debug(f"Retrieved scopes from state token: {scopes}")
scopes = handler.handle_default_scopes(scopes)
credentials = handler.exchange_code_for_tokens(code, scopes)
credentials = handler.exchange_code_for_tokens(
code, scopes, valid_state.code_verifier
)
logger.debug(f"Received credentials with final scopes: {credentials.scopes}")
# Check if the granted scopes are sufficient for the requested scopes
@ -317,7 +320,8 @@ async def webhook_ingress_generic(
continue
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
executor.add_execution(
node.graph_id,
graph_id=node.graph_id,
graph_version=node.graph_version,
data={f"webhook_{webhook_id}_payload": payload},
user_id=webhook.user_id,
)

View File

@ -58,6 +58,21 @@ class UpdatePermissionsRequest(pydantic.BaseModel):
permissions: List[APIKeyPermission]
class Pagination(pydantic.BaseModel):
total_items: int = pydantic.Field(
description="Total number of items.", examples=[42]
)
total_pages: int = pydantic.Field(
description="Total number of pages.", examples=[2]
)
current_page: int = pydantic.Field(
description="Current_page page number.", examples=[1]
)
page_size: int = pydantic.Field(
description="Number of items per page.", examples=[25]
)
class RequestTopUp(pydantic.BaseModel):
amount: int
"""Amount of credits to top up."""

View File

@ -2,6 +2,7 @@ import contextlib
import logging
import typing
import autogpt_libs.auth.models
import fastapi
import fastapi.responses
import starlette.middleware.cors
@ -16,7 +17,9 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.model
import backend.server.v2.library.routes
import backend.server.v2.store.model
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
@ -117,9 +120,24 @@ class AgentServer(backend.util.service.AppProcess):
@staticmethod
async def test_execute_graph(
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
graph_id: str,
graph_version: int,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
return backend.server.routers.v1.execute_graph(
graph_id, graph_version, node_input, user_id
)
@staticmethod
async def test_get_graph(
graph_id: str,
graph_version: int,
user_id: str,
):
return await backend.server.routers.v1.get_graph(
graph_id, user_id, graph_version
)
@staticmethod
async def test_create_graph(
@ -149,5 +167,71 @@ class AgentServer(backend.util.service.AppProcess):
async def test_delete_graph(graph_id: str, user_id: str):
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
@staticmethod
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
return await backend.server.v2.library.routes.presets.get_presets(
user_id=user_id, page=page, page_size=page_size
)
@staticmethod
async def test_get_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.get_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.create_preset(
preset=preset, user_id=user_id
)
@staticmethod
async def test_update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: str,
):
return await backend.server.v2.library.routes.presets.update_preset(
preset_id=preset_id, preset=preset, user_id=user_id
)
@staticmethod
async def test_delete_preset(preset_id: str, user_id: str):
return await backend.server.v2.library.routes.presets.delete_preset(
preset_id=preset_id, user_id=user_id
)
@staticmethod
async def test_execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: dict[typing.Any, typing.Any],
user_id: str,
):
return await backend.server.v2.library.routes.presets.execute_preset(
graph_id=graph_id,
graph_version=graph_version,
preset_id=preset_id,
node_input=node_input,
user_id=user_id,
)
@staticmethod
async def test_create_store_listing(
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
):
return await backend.server.v2.store.routes.create_submission(request, user_id)
@staticmethod
async def test_review_store_listing(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: autogpt_libs.auth.models.User,
):
return await backend.server.v2.store.routes.review_submission(request, user)
def set_test_dependency_overrides(self, overrides: dict):
app.dependency_overrides.update(overrides)

View File

@ -14,6 +14,7 @@ from typing_extensions import Optional, TypedDict
import backend.data.block
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.v2.library.db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.api_key import (
@ -229,11 +230,6 @@ async def get_graph(
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
@v1_router.get(
path="/templates/{graph_id}/versions",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
@ -249,12 +245,11 @@ async def get_graph_all_versions(
async def create_new_graph(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
return await do_create_graph(create_graph, user_id=user_id)
async def do_create_graph(
create_graph: CreateGraph,
is_template: bool,
# user_id doesn't have to be annotated like on other endpoints,
# because create_graph isn't used directly as an endpoint
user_id: str,
@ -266,7 +261,6 @@ async def do_create_graph(
graph = await graph_db.get_graph(
create_graph.template_id,
create_graph.template_version,
template=True,
user_id=user_id,
)
if not graph:
@ -274,13 +268,18 @@ async def do_create_graph(
400, detail=f"Template #{create_graph.template_id} not found"
)
graph.version = 1
# Create a library agent for the new graph
await backend.server.v2.library.db.create_library_agent(
graph.id,
graph.version,
user_id,
)
else:
raise HTTPException(
status_code=400, detail="Either graph or template_id must be provided."
)
graph.is_template = is_template
graph.is_active = not is_template
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
graph = await graph_db.create_graph(graph, user_id=user_id)
@ -310,11 +309,6 @@ async def delete_graph(
@v1_router.put(
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
)
@v1_router.put(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def update_graph(
graph_id: str,
graph: graph_db.Graph,
@ -347,6 +341,11 @@ async def update_graph(
if new_graph_version.is_active:
# Keep the library agent up to date with the new active version
await backend.server.v2.library.db.update_agent_version_in_library(
user_id, graph.id, graph.version
)
def get_credentials(credentials_id: str) -> "Credentials | None":
return integration_creds_manager.get(user_id, credentials_id)
@ -402,6 +401,12 @@ async def set_graph_active_version(
version=new_active_version,
user_id=user_id,
)
# Keep the library agent up to date with the new active version
await backend.server.v2.library.db.update_agent_version_in_library(
user_id, new_active_graph.id, new_active_graph.version
)
if current_active_graph and current_active_graph.version != new_active_version:
# Handle deactivation of the previously active version
await on_graph_deactivate(
@ -417,12 +422,13 @@ async def set_graph_active_version(
)
def execute_graph(
graph_id: str,
graph_version: int,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[str, Any]: # FIXME: add proper return type
try:
graph_exec = execution_manager_client().add_execution(
graph_id, node_input, user_id=user_id
graph_id, node_input, user_id=user_id, graph_version=graph_version
)
return {"id": graph_exec.graph_exec_id}
except Exception as e:
@ -477,47 +483,6 @@ async def get_graph_run_node_execution_results(
return await execution_db.get_execution_results(graph_exec_id)
########################################################
##################### Templates ########################
########################################################
@v1_router.get(
path="/templates",
tags=["graphs", "templates"],
dependencies=[Depends(auth_middleware)],
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
@v1_router.get(
path="/templates/{graph_id}",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_template(
graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(graph_id, version, template=True)
if not graph:
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
return graph
@v1_router.post(
path="/templates",
tags=["templates", "graphs"],
dependencies=[Depends(auth_middleware)],
)
async def create_new_template(
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
########################################################
##################### Schedules ########################
########################################################
@ -590,7 +555,7 @@ def get_execution_schedules(
@v1_router.post(
"/api-keys",
response_model=list[CreateAPIKeyResponse] | dict[str, str],
response_model=CreateAPIKeyResponse,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@ -632,7 +597,7 @@ async def get_api_keys(
@v1_router.get(
"/api-keys/{key_id}",
response_model=list[APIKeyWithoutHash] | dict[str, str],
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@ -653,7 +618,7 @@ async def get_api_key(
@v1_router.delete(
"/api-keys/{key_id}",
response_model=list[APIKeyWithoutHash] | dict[str, str],
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@ -675,7 +640,7 @@ async def delete_api_key(
@v1_router.post(
"/api-keys/{key_id}/suspend",
response_model=list[APIKeyWithoutHash] | dict[str, str],
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)
@ -697,7 +662,7 @@ async def suspend_key(
@v1_router.put(
"/api-keys/{key_id}/permissions",
response_model=list[APIKeyWithoutHash] | dict[str, str],
response_model=APIKeyWithoutHash,
tags=["api-keys"],
dependencies=[Depends(auth_middleware)],
)

View File

@ -1,5 +1,5 @@
import json
import logging
from typing import List
import prisma.errors
import prisma.models
@ -7,6 +7,7 @@ import prisma.types
import backend.data.graph
import backend.data.includes
import backend.server.model
import backend.server.v2.library.model
import backend.server.v2.store.exceptions
@ -14,90 +15,152 @@ logger = logging.getLogger(__name__)
async def get_library_agents(
user_id: str,
) -> List[backend.server.v2.library.model.LibraryAgent]:
"""
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
"""
logger.debug(f"Getting library agents for user {user_id}")
user_id: str, search_query: str | None = None
) -> list[backend.server.v2.library.model.LibraryAgent]:
logger.debug(
f"Fetching library agents for user_id={user_id} search_query={search_query}"
)
try:
# Get agents created by user with nodes and links
user_created = await prisma.models.AgentGraph.prisma().find_many(
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
if search_query and len(search_query.strip()) > 100:
logger.warning(f"Search query too long: {search_query}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Search query is too long."
)
# Get agents in user's library with nodes and links
library_agents = await prisma.models.UserAgent.prisma().find_many(
where=prisma.types.UserAgentWhereInput(
userId=user_id, isDeleted=False, isArchived=False
),
include={
where_clause = prisma.types.LibraryAgentWhereInput(
userId=user_id,
isDeleted=False,
isArchived=False,
)
if search_query:
where_clause["OR"] = [
{
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
"is": {"name": {"contains": search_query, "mode": "insensitive"}}
}
},
{
"Agent": {
"is": {
"description": {"contains": search_query, "mode": "insensitive"}
}
}
},
]
try:
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include={
"Agent": {
"include": {
"AgentNodes": {"include": {"Input": True, "Output": True}}
}
}
},
order=[{"updatedAt": "desc"}],
)
logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
return [
backend.server.v2.library.model.LibraryAgent.from_db(agent)
for agent in library_agents
]
except prisma.errors.PrismaError as e:
logger.error(f"Database error fetching library agents: {e}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Unable to fetch library agents."
)
# Convert to Graph models first
graphs = []
# Add user created agents
for agent in user_created:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent))
except Exception as e:
logger.error(f"Error processing user created agent {agent.id}: {e}")
continue
async def create_library_agent(
agent_id: str, agent_version: int, user_id: str
) -> prisma.models.LibraryAgent:
"""
Adds an agent to the user's library (LibraryAgent table)
"""
# Add library agents
for agent in library_agents:
if agent.Agent:
try:
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
except Exception as e:
logger.error(f"Error processing library agent {agent.agentId}: {e}")
continue
try:
# Convert Graph models to LibraryAgent models
result = []
for graph in graphs:
result.append(
backend.server.v2.library.model.LibraryAgent(
id=graph.id,
version=graph.version,
is_active=graph.is_active,
name=graph.name,
description=graph.description,
isCreatedByUser=any(a.id == graph.id for a in user_created),
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
library_agent = await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
agentId=agent_id,
agentVersion=agent_version,
isCreatedByUser=False,
useGraphIsActiveVersion=True,
)
logger.debug(f"Found {len(result)} library agents")
return result
)
return library_agent
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting library agents: {str(e)}")
logger.error(f"Database error creating agent to library: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch library agents"
"Failed to create agent to library"
) from e
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
async def update_agent_version_in_library(
user_id: str, agent_id: str, agent_version: int
) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
Updates the agent version in the library
"""
try:
await prisma.models.LibraryAgent.prisma().update(
where={
"userId": user_id,
"agentId": agent_id,
"useGraphIsActiveVersion": True,
},
data=prisma.types.LibraryAgentUpdateInput(
Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
connect=prisma.types.AgentGraphWhereUniqueInput(
id=agent_id,
version=agent_version,
),
),
),
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating agent version in library: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update agent version in library"
) from e
async def update_library_agent(
library_agent_id: str,
user_id: str,
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> None:
"""
Updates the library agent with the given fields
"""
try:
await prisma.models.LibraryAgent.prisma().update(
where={"id": library_agent_id, "userId": user_id},
data=prisma.types.LibraryAgentUpdateInput(
useGraphIsActiveVersion=auto_update_version,
isFavorite=is_favorite,
isArchived=is_archived,
isDeleted=is_deleted,
),
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error updating library agent: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to update library agent"
) from e
async def add_store_agent_to_library(
store_listing_version_id: str, user_id: str
) -> None:
"""
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
if they don't already have it
"""
logger.debug(
@ -131,7 +194,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
)
# Check if user already has this agent
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
where={
"userId": user_id,
"agentId": agent.id,
@ -145,9 +208,9 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
)
return
# Create UserAgent entry
await prisma.models.UserAgent.prisma().create(
data=prisma.types.UserAgentCreateInput(
# Create LibraryAgent entry
await prisma.models.LibraryAgent.prisma().create(
data=prisma.types.LibraryAgentCreateInput(
userId=user_id,
agentId=agent.id,
agentVersion=agent.version,
@ -163,3 +226,116 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to add agent to library"
) from e
##############################################
########### Presets DB Functions #############
##############################################
async def get_presets(
user_id: str, page: int, page_size: int
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
try:
presets = await prisma.models.AgentPreset.prisma().find_many(
where={"userId": user_id},
skip=page * page_size,
take=page_size,
)
total_items = await prisma.models.AgentPreset.prisma().count(
where={"userId": user_id},
)
total_pages = (total_items + page_size - 1) // page_size
presets = [
backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
for preset in presets
]
return backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=presets,
pagination=backend.server.model.Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting presets: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch presets"
) from e
async def get_preset(
user_id: str, preset_id: str
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
try:
preset = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
)
if not preset:
return None
return backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch preset"
) from e
async def create_or_update_preset(
user_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
preset_id: str | None = None,
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
new_preset = await prisma.models.AgentPreset.prisma().upsert(
where={
"id": preset_id if preset_id else "",
},
data={
"create": {
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": json.dumps(data)}
for name, data in preset.inputs.items()
]
},
},
"update": {
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
},
},
)
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create preset"
) from e
async def delete_preset(user_id: str, preset_id: str) -> None:
try:
await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id, "userId": user_id},
data={"isDeleted": True},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to delete preset"
) from e

View File

@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
]
mock_library_agents = [
prisma.models.UserAgent(
prisma.models.LibraryAgent(
id="ua1",
userId="test-user",
agentId="agent2",
@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
createdAt=datetime.now(),
updatedAt=datetime.now(),
isFavorite=False,
useGraphIsActiveVersion=True,
Agent=prisma.models.AgentGraph(
id="agent2",
version=1,
@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
return_value=mock_user_created
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_many = mocker.AsyncMock(
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_many = mocker.AsyncMock(
return_value=mock_library_agents
)
@ -76,40 +77,16 @@ async def test_get_library_agents(mocker):
result = await db.get_library_agents("test-user")
# Verify results
assert len(result) == 2
assert result[0].id == "agent1"
assert result[0].name == "Test Agent 1"
assert result[0].description == "Test Description 1"
assert result[0].isCreatedByUser is True
assert result[1].id == "agent2"
assert result[1].name == "Test Agent 2"
assert result[1].description == "Test Description 2"
assert result[1].isCreatedByUser is False
# Verify mocks called correctly
mock_agent_graph.return_value.find_many.assert_called_once_with(
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
)
mock_user_agent.return_value.find_many.assert_called_once_with(
where=prisma.types.UserAgentWhereInput(
userId="test-user", isDeleted=False, isArchived=False
),
include={
"Agent": {
"include": {
"AgentNodes": {
"include": {
"Input": True,
"Output": True,
"Webhook": True,
"AgentBlock": True,
}
}
}
}
},
)
assert len(result) == 1
assert result[0].id == "ua1"
assert result[0].name == "Test Agent 2"
assert result[0].description == "Test Description 2"
assert result[0].is_created_by_user is False
assert result[0].is_latest_version is True
assert result[0].is_favorite is False
assert result[0].agent_id == "agent2"
assert result[0].agent_version == 1
assert result[0].preset_id is None
@pytest.mark.asyncio
@ -152,26 +129,26 @@ async def test_add_agent_to_library(mocker):
return_value=mock_store_listing
)
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_user_agent.return_value.create = mocker.AsyncMock()
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock()
# Call function
await db.add_agent_to_library("version123", "test-user")
await db.add_store_agent_to_library("version123", "test-user")
# Verify mocks called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
where={"id": "version123"}, include={"Agent": True}
)
mock_user_agent.return_value.find_first.assert_called_once_with(
mock_library_agent.return_value.find_first.assert_called_once_with(
where={
"userId": "test-user",
"agentId": "agent1",
"agentVersion": 1,
}
)
mock_user_agent.return_value.create.assert_called_once_with(
data=prisma.types.UserAgentCreateInput(
mock_library_agent.return_value.create.assert_called_once_with(
data=prisma.types.LibraryAgentCreateInput(
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
)
)
@ -189,7 +166,7 @@ async def test_add_agent_to_library_not_found(mocker):
# Call function and verify exception
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
await db.add_agent_to_library("version123", "test-user")
await db.add_store_agent_to_library("version123", "test-user")
# Verify mock called correctly
mock_store_listing_version.return_value.find_unique.assert_called_once_with(

View File

@ -1,16 +1,112 @@
import datetime
import json
import typing
import prisma.models
import pydantic
import backend.data.block
import backend.data.graph
import backend.server.model
class LibraryAgent(pydantic.BaseModel):
id: str # Changed from agent_id to match GraphMeta
version: int # Changed from agent_version to match GraphMeta
is_active: bool # Added to match GraphMeta
agent_id: str
agent_version: int # Changed from agent_version to match GraphMeta
preset_id: str | None
updated_at: datetime.datetime
name: str
description: str
isCreatedByUser: bool
# Made input_schema and output_schema match GraphMeta's type
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
is_favorite: bool
is_created_by_user: bool
is_latest_version: bool
@staticmethod
def from_db(agent: prisma.models.LibraryAgent):
if not agent.Agent:
raise ValueError("AgentGraph is required")
graph = backend.data.graph.GraphModel.from_db(agent.Agent)
agent_updated_at = agent.Agent.updatedAt
lib_agent_updated_at = agent.updatedAt
# Take the latest updated_at timestamp either when the graph was updated or the library agent was updated
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
else lib_agent_updated_at
)
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
updated_at=updated_at,
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
is_favorite=agent.isFavorite,
is_created_by_user=agent.isCreatedByUser,
is_latest_version=graph.is_active,
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
)
class LibraryAgentPreset(pydantic.BaseModel):
id: str
updated_at: datetime.datetime
agent_id: str
agent_version: int
name: str
description: str
is_active: bool
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
@staticmethod
def from_db(preset: prisma.models.AgentPreset):
input_data = {}
for data in preset.InputPresets or []:
input_data[data.name] = json.loads(data.data)
return LibraryAgentPreset(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
agent_version=preset.agentVersion,
name=preset.name,
description=preset.description,
is_active=preset.isActive,
inputs=input_data,
)
class LibraryAgentPresetResponse(pydantic.BaseModel):
presets: list[LibraryAgentPreset]
pagination: backend.server.model.Pagination
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
agent_id: str
agent_version: int
is_active: bool

View File

@ -1,23 +1,35 @@
import datetime
import prisma.models
import backend.data.block
import backend.server.model
import backend.server.v2.library.model
def test_library_agent():
agent = backend.server.v2.library.model.LibraryAgent(
id="test-agent-123",
version=1,
is_active=True,
agent_id="agent-123",
agent_version=1,
preset_id=None,
updated_at=datetime.datetime.now(),
name="Test Agent",
description="Test description",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
)
assert agent.id == "test-agent-123"
assert agent.version == 1
assert agent.is_active is True
assert agent.agent_id == "agent-123"
assert agent.agent_version == 1
assert agent.name == "Test Agent"
assert agent.description == "Test description"
assert agent.isCreatedByUser is False
assert agent.is_favorite is False
assert agent.is_created_by_user is False
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
@ -25,19 +37,148 @@ def test_library_agent():
def test_library_agent_with_user_created():
agent = backend.server.v2.library.model.LibraryAgent(
id="user-agent-456",
version=2,
is_active=True,
agent_id="agent-456",
agent_version=2,
preset_id=None,
updated_at=datetime.datetime.now(),
name="User Created Agent",
description="An agent created by the user",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
)
assert agent.id == "user-agent-456"
assert agent.version == 2
assert agent.is_active is True
assert agent.agent_id == "agent-456"
assert agent.agent_version == 2
assert agent.name == "User Created Agent"
assert agent.description == "An agent created by the user"
assert agent.isCreatedByUser is True
assert agent.is_favorite is False
assert agent.is_created_by_user is True
assert agent.is_latest_version is True
assert agent.input_schema == {"type": "object", "properties": {}}
assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_preset():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
updated_at=datetime.datetime.now(),
)
assert preset.id == "preset-123"
assert preset.name == "Test Preset"
assert preset.description == "Test preset description"
assert preset.agent_id == "test-agent-123"
assert preset.agent_version == 1
assert preset.is_active is True
assert preset.inputs == {
"input1": backend.data.block.BlockInput(
name="input1", data={"type": "string", "value": "test value"}
)
}
def test_library_agent_preset_response():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
updated_at=datetime.datetime.now(),
)
pagination = backend.server.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=10
)
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=[preset], pagination=pagination
)
assert len(response.presets) == 1
assert response.presets[0].id == "preset-123"
assert response.pagination.total_items == 1
assert response.pagination.total_pages == 1
assert response.pagination.current_page == 1
assert response.pagination.page_size == 10
def test_create_library_agent_preset_request():
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="New Preset",
description="New preset description",
agent_id="agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
)
assert request.name == "New Preset"
assert request.description == "New preset description"
assert request.agent_id == "agent-123"
assert request.agent_version == 1
assert request.is_active is True
assert request.inputs == {
"input1": backend.data.block.BlockInput(
name="input1", data={"type": "string", "value": "test value"}
)
}
def test_library_agent_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
agentId="agent-123",
agentVersion=1,
name="Test Agent",
description="Test agent description",
isActive=True,
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data='{"type": "string", "value": "test value"}',
)
],
)
# Convert to LibraryAgentPreset
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test agent description"
assert agent.inputs == {"input1": {"type": "string", "value": "test value"}}

View File

@ -1,123 +0,0 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import prisma
import backend.data.graph
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
integration_creds_manager = (
backend.integrations.creds_manager.IntegrationCredentialsManager()
)
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception:
logger.exception("Exception occurred whilst getting library agents")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Get the graph from the store listing
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
agent = store_listing_version.Agent
if agent.userId == user_id:
raise fastapi.HTTPException(
status_code=400, detail="Cannot add own agent to library"
)
# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True, user_id=user_id
)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent {agent.id} not found"
)
# Create a deep copy with new IDs
graph.version = 1
graph.is_template = False
graph.is_active = True
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
# Save the new graph
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
graph = (
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
)
return fastapi.Response(status_code=201)
except Exception:
logger.exception("Exception occurred whilst adding agent to library")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)

View File

@ -0,0 +1,9 @@
import fastapi
from .agents import router as agents_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(agents_router)

View File

@ -0,0 +1,148 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import autogpt_libs.utils.cache
import fastapi
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.store.exceptions
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception as e:
logger.exception(f"Exception occurred whilst getting library agents: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Use the database function to add the agent to the library
await backend.server.v2.library.db.add_store_agent_to_library(
store_listing_version_id, user_id
)
return fastapi.Response(status_code=201)
except backend.server.v2.store.exceptions.AgentNotFoundError:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
except backend.server.v2.store.exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst adding agent to library: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
@router.put(
"/agents/{library_agent_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=204,
)
async def update_library_agent(
library_agent_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> fastapi.Response:
"""
Update the library agent with the given fields.
Args:
library_agent_id (str): ID of the library agent to update
user_id (str): ID of the authenticated user
auto_update_version (bool): Whether to auto-update the agent version
is_favorite (bool): Whether the agent is marked as favorite
is_archived (bool): Whether the agent is archived
is_deleted (bool): Whether the agent is deleted
Returns:
fastapi.Response: 204 status code on success
Raises:
HTTPException: If there is an error updating the library agent
"""
try:
# Use the database function to update the library agent
await backend.server.v2.library.db.update_library_agent(
library_agent_id,
user_id,
auto_update_version,
is_favorite,
is_archived,
is_deleted,
)
return fastapi.Response(status_code=204)
except backend.server.v2.store.exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst updating library agent: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst updating library agent: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)

View File

@ -0,0 +1,156 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import autogpt_libs.utils.cache
import fastapi
import backend.data.graph
import backend.executor
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.util.service
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
integration_creds_manager = (
backend.integrations.creds_manager.IntegrationCredentialsManager()
)
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get("/presets")
async def get_presets(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
page: int = 1,
page_size: int = 10,
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
try:
presets = await backend.server.v2.library.db.get_presets(
user_id, page, page_size
)
return presets
except Exception as e:
logger.exception(f"Exception occurred whilst getting presets: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
@router.get("/presets/{preset_id}")
async def get_preset(
preset_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(
status_code=404,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(f"Exception occurred whilst getting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
@router.post("/presets")
async def create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
return await backend.server.v2.library.db.create_or_update_preset(
user_id, preset
)
except Exception as e:
logger.exception(f"Exception occurred whilst creating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
@router.put("/presets/{preset_id}")
async def update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
return await backend.server.v2.library.db.create_or_update_preset(
user_id, preset, preset_id
)
except Exception as e:
logger.exception(f"Exception occurred whilst updating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
@router.delete("/presets/{preset_id}")
async def delete_preset(
preset_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
):
try:
await backend.server.v2.library.db.delete_preset(user_id, preset_id)
return fastapi.Response(status_code=204)
except Exception as e:
logger.exception(f"Exception occurred whilst deleting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
@router.post(
path="/presets/{preset_id}/execute",
tags=["presets"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: dict[typing.Any, typing.Any],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> dict[str, typing.Any]: # FIXME: add proper return type
try:
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
logger.info(f"Preset inputs: {preset.inputs}")
updated_node_input = node_input.copy()
# Merge in preset input values
for key, value in preset.inputs.items():
if key not in updated_node_input:
updated_node_input[key] = value
execution = execution_manager_client().add_execution(
graph_id=graph_id,
graph_version=graph_version,
data=updated_node_input,
user_id=user_id,
preset_id=preset_id,
)
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
return {"id": execution.graph_exec_id}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise fastapi.HTTPException(status_code=400, detail=msg)

View File

@ -1,3 +1,5 @@
import datetime
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
@ -35,21 +37,29 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
mocked_value = [
backend.server.v2.library.model.LibraryAgent(
id="test-agent-1",
version=1,
is_active=True,
agent_id="test-agent-1",
agent_version=1,
preset_id="preset-1",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=True,
is_latest_version=True,
name="Test Agent 1",
description="Test Description 1",
isCreatedByUser=True,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
backend.server.v2.library.model.LibraryAgent(
id="test-agent-2",
version=1,
is_active=True,
agent_id="test-agent-2",
agent_version=1,
preset_id="preset-2",
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
is_favorite=False,
is_created_by_user=False,
is_latest_version=True,
name="Test Agent 2",
description="Test Description 2",
isCreatedByUser=False,
input_schema={"type": "object", "properties": {}},
output_schema={"type": "object", "properties": {}},
),
@ -65,10 +75,10 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
for agent in response.json()
]
assert len(data) == 2
assert data[0].id == "test-agent-1"
assert data[0].isCreatedByUser is True
assert data[1].id == "test-agent-2"
assert data[1].isCreatedByUser is False
assert data[0].agent_id == "test-agent-1"
assert data[0].is_created_by_user is True
assert data[1].agent_id == "test-agent-2"
assert data[1].is_created_by_user is False
mock_db_call.assert_called_once_with("test-user-id")

View File

@ -325,7 +325,10 @@ async def get_store_submissions(
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
# Query submissions from database
submissions = await prisma.models.StoreSubmission.prisma().find_many(
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
where=where,
skip=skip,
take=page_size,
order=[{"date_submitted": "desc"}],
)
# Get total count for pagination
@ -405,9 +408,7 @@ async def delete_store_submission(
)
# Delete the submission
await prisma.models.StoreListing.prisma().delete(
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
)
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
logger.debug(
f"Successfully deleted submission {submission_id} for user {user_id}"
@ -504,7 +505,15 @@ async def create_store_submission(
"subHeading": sub_heading,
}
},
}
},
include={"StoreListingVersions": True},
)
slv_id = (
listing.StoreListingVersions[0].id
if listing.StoreListingVersions is not None
and len(listing.StoreListingVersions) > 0
else None
)
logger.debug(f"Created store listing for agent {agent_id}")
@ -521,6 +530,7 @@ async def create_store_submission(
status=prisma.enums.SubmissionStatus.PENDING,
runs=0,
rating=0.0,
store_listing_version_id=slv_id,
)
except (
@ -811,9 +821,7 @@ async def get_agent(
agent = store_listing_version.Agent
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True
)
graph = await backend.data.graph.get_graph(agent.id, agent.version)
if not graph:
raise fastapi.HTTPException(
@ -832,3 +840,74 @@ async def get_agent(
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch agent"
) from e
async def review_store_submission(
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
) -> prisma.models.StoreListingSubmission:
"""Review a store listing submission."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={"StoreListing": True},
)
)
if not store_listing_version or not store_listing_version.StoreListing:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
status = (
prisma.enums.SubmissionStatus.APPROVED
if is_approved
else prisma.enums.SubmissionStatus.REJECTED
)
create_data = prisma.types.StoreListingSubmissionCreateInput(
StoreListingVersion={"connect": {"id": store_listing_version_id}},
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
createdAt=datetime.now(),
updatedAt=datetime.now(),
)
update_data = prisma.types.StoreListingSubmissionUpdateInput(
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
updatedAt=datetime.now(),
)
if is_approved:
await prisma.models.StoreListing.prisma().update(
where={"id": store_listing_version.StoreListing.id},
data={"isApproved": True},
)
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
where={"storeListingVersionId": store_listing_version_id},
data=prisma.types.StoreListingSubmissionUpsertInput(
create=create_data,
update=update_data,
),
)
if not submission:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing submission {store_listing_version_id} not found",
)
return submission
except Exception as e:
logger.error(f"Error reviewing store submission: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to review store submission"
) from e

View File

@ -115,6 +115,7 @@ class StoreSubmission(pydantic.BaseModel):
status: prisma.enums.SubmissionStatus
runs: int
rating: float
store_listing_version_id: str | None = None
class StoreSubmissionsResponse(pydantic.BaseModel):
@ -151,3 +152,9 @@ class StoreReviewCreate(pydantic.BaseModel):
store_listing_version_id: str
score: int
comments: str | None = None
class ReviewSubmissionRequest(pydantic.BaseModel):
store_listing_version_id: str
isApproved: bool
comments: str

View File

@ -642,3 +642,33 @@ async def download_agent_file(
return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json"
)
@router.post(
"/submissions/review/{store_listing_version_id}",
tags=["store", "private"],
)
async def review_submission(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
):
# Proceed with the review submission logic
try:
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=request.store_listing_version_id,
is_approved=request.isApproved,
comments=request.comments,
reviewer_id=user.user_id,
)
return submission
except Exception:
logger.exception("Exception occurred whilst reviewing store submission")
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while reviewing the store submission"
},
)

View File

@ -253,7 +253,7 @@ async def block_autogen_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"input": "Write me a block that writes a string into a file."}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(

View File

@ -157,7 +157,7 @@ async def reddit_marketing_agent():
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
input_data = {"subreddit": "AutoGPT"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)

View File

@ -8,12 +8,19 @@ from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution
async def create_test_user() -> User:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser#example.com",
"name": "Test User",
}
async def create_test_user(alt_user: bool = False) -> User:
if alt_user:
test_user_data = {
"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1b",
"email": "testuser2#example.com",
"name": "Test User 2",
}
else:
test_user_data = {
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
"email": "testuser#example.com",
"name": "Test User",
}
user = await get_or_create_user(test_user_data)
return user
@ -38,7 +45,7 @@ def create_test_graph() -> graph.Graph:
graph.Node(
block_id=FillTextTemplateBlock().id,
input_default={
"format": "{a}, {b}{c}",
"format": "{{a}}, {{b}}{{c}}",
"values_#_c": "!!!",
},
),
@ -79,7 +86,7 @@ async def sample_agent():
test_graph = await create_graph(create_test_graph(), test_user.id)
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
print(response)
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)

View File

@ -264,6 +264,10 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
notion_client_secret: str = Field(
default="", description="Notion OAuth client secret"
)
twitter_client_id: str = Field(default="", description="Twitter/X OAuth client ID")
twitter_client_secret: str = Field(
default="", description="Twitter/X OAuth client secret"
)
openai_api_key: str = Field(default="", description="OpenAI API key")
anthropic_api_key: str = Field(default="", description="Anthropic API key")
@ -300,7 +304,10 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
jina_api_key: str = Field(default="", description="Jina API Key")
unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")
fal_key: str = Field(default="", description="FAL API key")
fal_api_key: str = Field(default="", description="FAL API key")
exa_api_key: str = Field(default="", description="Exa API key")
e2b_api_key: str = Field(default="", description="E2B API key")
nvidia_api_key: str = Field(default="", description="Nvidia API key")
stripe_api_key: str = Field(default="", description="Stripe API Key")
stripe_webhook_secret: str = Field(default="", description="Stripe Webhook Secret")

View File

@ -1,5 +1,3 @@
import re
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
@ -15,8 +13,5 @@ class TextFormatter:
self.env.globals.clear()
def format_string(self, template_str: str, values=None, **kwargs) -> str:
# For python.format compatibility: replace all {...} with {{..}}.
# But avoid replacing {{...}} to {{{...}}}.
template_str = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", template_str)
template = self.env.from_string(template_str)
return template.render(values or {}, **kwargs)

View File

@ -0,0 +1,86 @@
/*
Warnings:
- You are about replace a single brace string input format for the following blocks:
- AgentOutputBlock
- FillTextTemplateBlock
- AITextGeneratorBlock
- AIStructuredResponseGeneratorBlock
with a double brace format.
- This migration can be slow for a large updated AgentNode tables.
*/
BEGIN;
SET LOCAL statement_timeout = '10min';
WITH to_update AS (
SELECT
"id",
"agentBlockId",
"constantInput"::jsonb AS j
FROM "AgentNode"
WHERE
"agentBlockId" IN (
'363ae599-353e-4804-937e-b2ee3cef3da4', -- AgentOutputBlock
'db7d8f02-2f44-4c55-ab7a-eae0941f0c30', -- FillTextTemplateBlock
'1f292d4a-41a4-4977-9684-7c8d560b9f91', -- AITextGeneratorBlock
'ed55ac19-356e-4243-a6cb-bc599e9b716f' -- AIStructuredResponseGeneratorBlock
)
AND (
"constantInput"::jsonb->>'format' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
OR "constantInput"::jsonb->>'prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
OR "constantInput"::jsonb->>'sys_prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
)
),
updated_rows AS (
SELECT
"id",
"agentBlockId",
(
j
-- Update "format" if it has a single-brace placeholder
|| CASE WHEN j->>'format' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
THEN jsonb_build_object(
'format',
regexp_replace(
j->>'format',
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
'{{\1}}',
'g'
)
)
ELSE '{}'::jsonb
END
-- Update "prompt" if it has a single-brace placeholder
|| CASE WHEN j->>'prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
THEN jsonb_build_object(
'prompt',
regexp_replace(
j->>'prompt',
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
'{{\1}}',
'g'
)
)
ELSE '{}'::jsonb
END
-- Update "sys_prompt" if it has a single-brace placeholder
|| CASE WHEN j->>'sys_prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
THEN jsonb_build_object(
'sys_prompt',
regexp_replace(
j->>'sys_prompt',
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
'{{\1}}',
'g'
)
)
ELSE '{}'::jsonb
END
)::text AS "newConstantInput"
FROM to_update
)
UPDATE "AgentNode" AS an
SET "constantInput" = ur."newConstantInput"
FROM updated_rows ur
WHERE an."id" = ur."id";
COMMIT;

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "AgentPreset" ADD COLUMN "isDeleted" BOOLEAN NOT NULL DEFAULT false;

View File

@ -0,0 +1,46 @@
/*
Warnings:
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
*/
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
-- DropTable
DROP TABLE "UserAgent";
-- CreateTable
CREATE TABLE "LibraryAgent" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"agentPresetId" TEXT,
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
"isArchived" BOOLEAN NOT NULL DEFAULT false,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;

View File

@ -0,0 +1,29 @@
/*
Warnings:
- A unique constraint covering the columns `[agentId]` on the table `StoreListing` will be added. If there are existing duplicate values, this will fail.
*/
-- DropIndex
DROP INDEX "StoreListing_agentId_idx";
-- DropIndex
DROP INDEX "StoreListing_isApproved_idx";
-- DropIndex
DROP INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx";
-- CreateIndex
CREATE INDEX "StoreListing_agentId_owningUserId_idx" ON "StoreListing"("agentId", "owningUserId");
-- CreateIndex
CREATE INDEX "StoreListing_isDeleted_isApproved_idx" ON "StoreListing"("isDeleted", "isApproved");
-- CreateIndex
CREATE INDEX "StoreListing_isDeleted_idx" ON "StoreListing"("isDeleted");
-- CreateIndex
CREATE UNIQUE INDEX "StoreListing_agentId_key" ON "StoreListing"("agentId");
-- CreateIndex
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isDeleted_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isDeleted");

File diff suppressed because it is too large Load Diff

View File

@ -40,8 +40,9 @@ redis = "^5.2.0"
sentry-sdk = "2.19.2"
strenum = "^0.4.9"
stripe = "^11.3.0"
supabase = "^2.10.0"
supabase = "2.11.0"
tenacity = "^9.0.0"
tweepy = "^4.14.0"
uvicorn = { extras = ["standard"], version = "^0.34.0" }
websockets = "^13.1"
youtube-transcript-api = "^0.6.2"

View File

@ -30,7 +30,7 @@ model User {
CreditTransaction CreditTransaction[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
LibraryAgent LibraryAgent[]
Profile Profile[]
StoreListing StoreListing[]
@ -65,7 +65,7 @@ model AgentGraph {
AgentGraphExecution AgentGraphExecution[]
AgentPreset AgentPreset[]
UserAgent UserAgent[]
LibraryAgent LibraryAgent[]
StoreListing StoreListing[]
StoreListingVersion StoreListingVersion?
@ -103,15 +103,17 @@ model AgentPreset {
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
UserAgents UserAgent[]
LibraryAgents LibraryAgent[]
AgentExecution AgentGraphExecution[]
isDeleted Boolean @default(false)
@@index([userId])
}
// For the library page
// It is a user controlled list of agents, that they will see in there library
model UserAgent {
model LibraryAgent {
id String @id @default(uuid())
createdAt DateTime @default(now())
updatedAt DateTime @default(now()) @updatedAt
@ -126,6 +128,8 @@ model UserAgent {
agentPresetId String?
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
useGraphIsActiveVersion Boolean @default(false)
isFavorite Boolean @default(false)
isCreatedByUser Boolean @default(false)
isArchived Boolean @default(false)
@ -235,7 +239,7 @@ model AgentGraphExecution {
AgentNodeExecutions AgentNodeExecution[]
// Link to User model
// Link to User model -- Executed by this user
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@ -443,6 +447,8 @@ view Creator {
agent_rating Float
agent_runs Int
is_featured Boolean
// Index or unique are not applied to views
}
view StoreAgent {
@ -465,11 +471,7 @@ view StoreAgent {
rating Float
versions String[]
@@unique([creator_username, slug])
@@index([creator_username])
@@index([featured])
@@index([categories])
@@index([storeListingVersionId])
// Index or unique are not applied to views
}
view StoreSubmission {
@ -487,7 +489,7 @@ view StoreSubmission {
agent_id String
agent_version Int
@@index([user_id])
// Index or unique are not applied to views
}
model StoreListing {
@ -510,9 +512,13 @@ model StoreListing {
StoreListingVersions StoreListingVersion[]
StoreListingSubmission StoreListingSubmission[]
@@index([isApproved])
@@index([agentId])
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
@@unique([agentId])
@@index([agentId, owningUserId])
@@index([owningUserId])
// Used in the view query
@@index([isDeleted, isApproved])
@@index([isDeleted])
}
model StoreListingVersion {
@ -553,7 +559,7 @@ model StoreListingVersion {
StoreListingReview StoreListingReview[]
@@unique([agentId, agentVersion])
@@index([agentId, agentVersion, isApproved])
@@index([agentId, agentVersion, isDeleted])
}
model StoreListingReview {

View File

@ -1,8 +1,11 @@
from typing import Any
from uuid import UUID
import autogpt_libs.auth.models
import fastapi.exceptions
import pytest
import backend.server.v2.store.model
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
from backend.data.block import BlockSchema
from backend.data.graph import Graph, Link, Node
@ -202,3 +205,92 @@ async def test_clean_graph(server: SpinTestServer):
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
)
assert input_node.input_default["value"] == ""
@pytest.mark.asyncio(scope="session")
async def test_access_store_listing_graph(server: SpinTestServer):
"""
Test the access of a store listing graph.
"""
graph = Graph(
id="test_clean_graph",
name="Test Clean Graph",
description="Test graph cleaning",
nodes=[
Node(
id="input_node",
block_id=AgentInputBlock().id,
input_default={
"name": "test_input",
"value": "test value",
"description": "Test input description",
},
),
],
links=[],
)
# Create graph and get model
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id=created_graph.id,
agent_version=created_graph.version,
slug="test-slug",
name="Test name",
sub_heading="Test sub heading",
video_url=None,
image_urls=[],
description="Test description",
categories=[],
)
# First we check the graph an not be accessed by a different user
with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
await server.agent_server.test_get_graph(
created_graph.id,
created_graph.version,
"3e53486c-cf57-477e-ba2a-cb02dc828e1b",
)
assert exc_info.value.status_code == 404
assert "Graph" in str(exc_info.value.detail)
# Now we create a store listing
store_listing = await server.agent_server.test_create_store_listing(
store_submission_request, DEFAULT_USER_ID
)
if isinstance(store_listing, fastapi.responses.JSONResponse):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
else None
)
assert slv_id is not None
admin = autogpt_libs.auth.models.User(
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
role="admin",
email="admin@example.com",
phone_number="1234567890",
)
await server.agent_server.test_review_store_listing(
backend.server.v2.store.model.ReviewSubmissionRequest(
store_listing_version_id=slv_id,
isApproved=True,
comments="Test comments",
),
admin,
)
# Now we check the graph can be accessed by a user that does not own the graph
got_graph = await server.agent_server.test_get_graph(
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
)
assert got_graph is not None

View File

@ -1,9 +1,13 @@
import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
from prisma.models import User
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
import backend.server.v2.library.model
import backend.server.v2.store.model
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
from backend.server.model import CreateGraph
@ -31,7 +35,7 @@ async def execute_graph(
# --- Test adding new executions --- #
response = await agent_server.test_execute_graph(
test_graph.id, input_data, test_user.id
test_graph.id, test_graph.version, input_data, test_user.id
)
graph_exec_id = response["id"]
logger.info(f"Created execution with ID: {graph_exec_id}")
@ -102,7 +106,7 @@ async def assert_sample_graph_executions(
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!!!"]}
assert exec.input_data == {
"format": "{a}, {b}{c}",
"format": "{{a}}, {{b}}{{c}}",
"values": {"a": "Hello", "b": "World", "c": "!!!"},
"values_#_a": "Hello",
"values_#_b": "World",
@ -287,3 +291,255 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
assert exec_data.status == execution.ExecutionStatus.COMPLETED
assert exec_data.output_data == {"result": [9]}
logger.info("Completed test_static_input_link_on_graph")
@pytest.mark.asyncio(scope="session")
async def test_execute_preset(server: SpinTestServer):
"""
Test executing a preset.
This test ensures that:
1. A preset can be successfully executed
2. The execution results are correct
Args:
server (SpinTestServer): The test server instance.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={},
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["World"]}
@pytest.mark.asyncio(scope="session")
async def test_execute_preset_with_clash(server: SpinTestServer):
"""
Test executing a preset with clashing input data.
"""
# Create test graph and user
nodes = [
graph.Node( # 0
block_id=AgentInputBlock().id,
input_default={"name": "dictionary"},
),
graph.Node( # 1
block_id=AgentInputBlock().id,
input_default={"name": "selected_value"},
),
graph.Node( # 2
block_id=StoreValueBlock().id,
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
),
graph.Node( # 3
block_id=FindInDictionaryBlock().id,
input_default={"key": "", "input": {}},
),
]
links = [
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="result",
sink_name="input",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[3].id,
source_name="result",
sink_name="key",
),
graph.Link(
source_id=nodes[2].id,
sink_id=nodes[3].id,
source_name="output",
sink_name="input",
),
]
test_graph = graph.Graph(
name="TestGraph",
description="Test graph",
nodes=nodes,
links=links,
)
test_user = await create_test_user()
test_graph = await create_graph(server, test_graph, test_user)
# Create preset with initial values
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="Test Preset With Clash",
description="Test preset with clashing input values",
agent_id=test_graph.id,
agent_version=test_graph.version,
inputs={
"dictionary": {"key1": "Hello", "key2": "World"},
"selected_value": "key2",
},
is_active=True,
)
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
# Execute preset with overriding values
result = await server.agent_server.test_execute_preset(
graph_id=test_graph.id,
graph_version=test_graph.version,
preset_id=created_preset.id,
node_input={"selected_value": "key1"},
user_id=test_user.id,
)
# Verify execution
assert result is not None
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
assert executions[3].status == execution.ExecutionStatus.COMPLETED
assert executions[3].output_data == {"output": ["Hello"]}
@pytest.mark.asyncio(scope="session")
async def test_store_listing_graph(server: SpinTestServer):
logger.info("Starting test_agent_execution")
test_user = await create_test_user()
test_graph = await create_graph(server, create_test_graph(), test_user)
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
agent_id=test_graph.id,
agent_version=test_graph.version,
slug="test-slug",
name="Test name",
sub_heading="Test sub heading",
video_url=None,
image_urls=[],
description="Test description",
categories=[],
)
store_listing = await server.agent_server.test_create_store_listing(
store_submission_request, test_user.id
)
if isinstance(store_listing, fastapi.responses.JSONResponse):
assert False, "Failed to create store listing"
slv_id = (
store_listing.store_listing_version_id
if store_listing.store_listing_version_id is not None
else None
)
assert slv_id is not None
admin = autogpt_libs.auth.models.User(
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
role="admin",
email="admin@example.com",
phone_number="1234567890",
)
await server.agent_server.test_review_store_listing(
backend.server.v2.store.model.ReviewSubmissionRequest(
store_listing_version_id=slv_id,
isApproved=True,
comments="Test comments",
),
admin,
)
alt_test_user = await create_test_user(alt_user=True)
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server,
test_graph,
alt_test_user,
data,
4,
)
await assert_sample_graph_executions(
server.agent_server, test_graph, alt_test_user, graph_exec_id
)
logger.info("Completed test_agent_execution")

View File

@ -140,10 +140,10 @@ async def main():
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
for user in users:
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
for _ in range(num_agents): # Create 1 UserAgent per user
for _ in range(num_agents): # Create 1 LibraryAgent per user
graph = random.choice(agent_graphs)
preset = random.choice(agent_presets)
user_agent = await db.useragent.create(
user_agent = await db.libraryagent.create(
data={
"userId": user.id,
"agentId": graph.id,

View File

@ -24,8 +24,8 @@
],
"dependencies": {
"@faker-js/faker": "^9.3.0",
"@hookform/resolvers": "^3.9.1",
"@next/third-parties": "^15.0.4",
"@hookform/resolvers": "^3.10.0",
"@next/third-parties": "^15.1.3",
"@radix-ui/react-alert-dialog": "^1.1.4",
"@radix-ui/react-avatar": "^1.1.1",
"@radix-ui/react-checkbox": "^1.1.2",
@ -60,25 +60,25 @@
"dotenv": "^16.4.7",
"elliptic": "6.6.1",
"embla-carousel-react": "^8.3.0",
"framer-motion": "^11.15.0",
"framer-motion": "^11.16.0",
"geist": "^1.3.1",
"launchdarkly-react-client-sdk": "^3.6.0",
"lucide-react": "^0.468.0",
"lucide-react": "^0.469.0",
"moment": "^2.30.1",
"next": "^14.2.13",
"next-themes": "^0.4.4",
"react": "^18",
"react-day-picker": "^9.4.4",
"react-day-picker": "^9.5.0",
"react-dom": "^18",
"react-hook-form": "^7.54.0",
"react-hook-form": "^7.54.2",
"react-icons": "^5.4.0",
"react-markdown": "^9.0.1",
"react-modal": "^3.16.1",
"react-markdown": "^9.0.3",
"react-modal": "^3.16.3",
"react-shepherd": "^6.1.6",
"recharts": "^2.14.1",
"tailwind-merge": "^2.5.5",
"tailwind-merge": "^2.6.0",
"tailwindcss-animate": "^1.0.7",
"uuid": "^11.0.3",
"uuid": "^11.0.4",
"zod": "^3.23.8"
},
"devDependencies": {
@ -95,16 +95,16 @@
"@storybook/test": "^8.3.5",
"@storybook/test-runner": "^0.21.0",
"@types/negotiator": "^0.6.3",
"@types/node": "^22.9.0",
"@types/node": "^22.10.5",
"@types/react": "^18",
"@types/react-dom": "^18",
"@types/react-modal": "^3.16.3",
"axe-playwright": "^2.0.3",
"chromatic": "^11.12.5",
"concurrently": "^9.1.1",
"chromatic": "^11.22.0",
"concurrently": "^9.1.2",
"eslint": "^8",
"eslint-config-next": "15.1.3",
"eslint-plugin-storybook": "^0.11.0",
"eslint-plugin-storybook": "^0.11.2",
"msw": "^2.7.0",
"msw-storybook-addon": "^2.0.3",
"postcss": "^8",

View File

@ -1,6 +1,6 @@
import React from "react";
import type { Metadata } from "next";
import { Inter } from "next/font/google";
import { Inter, Poppins } from "next/font/google";
import { Providers } from "@/app/providers";
import { cn } from "@/lib/utils";
import { Navbar } from "@/components/agptui/Navbar";
@ -10,8 +10,16 @@ import TallyPopupSimple from "@/components/TallyPopup";
import { GoogleAnalytics } from "@next/third-parties/google";
import { Toaster } from "@/components/ui/toaster";
import { IconType } from "@/components/ui/icons";
import { GeistSans } from "geist/font/sans";
import { GeistMono } from "geist/font/mono";
const inter = Inter({ subsets: ["latin"] });
// Fonts
const inter = Inter({ subsets: ["latin"], variable: "--font-inter" });
const poppins = Poppins({
subsets: ["latin"],
weight: ["400", "500", "600", "700"],
variable: "--font-poppins",
});
export const metadata: Metadata = {
title: "NextGen AutoGPT",
@ -24,7 +32,10 @@ export default async function RootLayout({
children: React.ReactNode;
}>) {
return (
<html lang="en">
<html
lang="en"
className={`${GeistSans.variable} ${GeistMono.variable} ${poppins.variable} ${inter.variable}`}
>
<body className={cn("antialiased transition-colors", inter.className)}>
<Providers
attribute="class"

View File

@ -0,0 +1,11 @@
import { APIKeysSection } from "@/components/agptui/composite/APIKeySection";
const ApiKeysPage = () => {
return (
<div className="w-full pr-4 pt-24 md:pt-0">
<APIKeysSection />
</div>
);
};
export default ApiKeysPage;

View File

@ -9,6 +9,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
{ text: "Agent dashboard", href: "/store/agent-dashboard" },
{ text: "Credits", href: "/store/credits" },
{ text: "Integrations", href: "/store/integrations" },
{ text: "API Keys", href: "/store/api_keys" },
{ text: "Profile", href: "/store/profile" },
{ text: "Settings", href: "/store/settings" },
],
@ -18,7 +19,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
return (
<div className="flex min-h-screen w-screen max-w-[1360px] flex-col lg:flex-row">
<Sidebar linkGroups={sidebarLinkGroups} />
<div className="pl-4">{children}</div>
<div className="flex-1 pl-4">{children}</div>
</div>
);
}

View File

@ -253,7 +253,13 @@ export function CustomNode({
!isHidden &&
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
<div key={propKey} data-id={`input-handle-${propKey}`}>
{isConnectable ? (
{isConnectable &&
!(
"oneOf" in propSchema &&
propSchema.oneOf &&
"discriminator" in propSchema &&
propSchema.discriminator
) ? (
<NodeHandle
keyName={propKey}
isConnected={isConnected}

View File

@ -61,7 +61,7 @@ const TallyPopupSimple = () => {
<Button
variant="default"
onClick={resetTutorial}
className="font-inter mb-0 h-14 w-28 rounded-2xl bg-[rgba(65,65,64,1)] text-left text-lg font-medium leading-6"
className="mb-0 h-14 w-28 rounded-2xl bg-[rgba(65,65,64,1)] text-left font-inter text-lg font-medium leading-6"
>
Tutorial
</Button>

View File

@ -2,7 +2,7 @@ import * as React from "react";
import Link from "next/link";
import { Button } from "./Button";
import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet";
import { Menu } from "lucide-react";
import { KeyIcon, Menu } from "lucide-react";
import {
IconDashboardLayout,
IconIntegrations,
@ -74,6 +74,15 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
Integrations
</div>
</Link>
<Link
href="/store/api_keys"
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
>
<KeyIcon className="h-6 w-6" />
<div className="p-ui-medium text-base font-medium leading-normal">
API Keys
</div>
</Link>
<Link
href="/store/profile"
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
@ -129,6 +138,15 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
Integrations
</div>
</Link>
<Link
href="/store/api_keys"
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
>
<KeyIcon className="h-6 w-6" strokeWidth={1} />
<div className="p-ui-medium text-base font-medium leading-normal">
API Keys
</div>
</Link>
<Link
href="/store/profile"
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"

View File

@ -0,0 +1,296 @@
"use client";
import { useState, useEffect } from "react";
import { APIKey, APIKeyPermission } from "@/lib/autogpt-server-api/types";
import { LuCopy } from "react-icons/lu";
import { Loader2, MoreVertical } from "lucide-react";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useToast } from "@/components/ui/use-toast";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { Label } from "@/components/ui/label";
import { Input } from "@/components/ui/input";
import { Checkbox } from "@/components/ui/checkbox";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Badge } from "@/components/ui/badge";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
export function APIKeysSection() {
const [apiKeys, setApiKeys] = useState<APIKey[]>([]);
const [isLoading, setIsLoading] = useState(true);
const [isCreateOpen, setIsCreateOpen] = useState(false);
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
const [newKeyName, setNewKeyName] = useState("");
const [newKeyDescription, setNewKeyDescription] = useState("");
const [newApiKey, setNewApiKey] = useState("");
const [selectedPermissions, setSelectedPermissions] = useState<
APIKeyPermission[]
>([]);
const { toast } = useToast();
const api = useBackendAPI();
useEffect(() => {
loadAPIKeys();
}, []);
const loadAPIKeys = async () => {
setIsLoading(true);
try {
const keys = await api.listAPIKeys();
setApiKeys(keys.filter((key) => key.status === "ACTIVE"));
} finally {
setIsLoading(false);
}
};
const handleCreateKey = async () => {
try {
const response = await api.createAPIKey(
newKeyName,
selectedPermissions,
newKeyDescription,
);
setNewApiKey(response.plain_text_key);
setIsCreateOpen(false);
setIsKeyDialogOpen(true);
loadAPIKeys();
} catch (error) {
toast({
title: "Error",
description: "Failed to create AutoGPT Platform API key",
variant: "destructive",
});
}
};
const handleCopyKey = () => {
navigator.clipboard.writeText(newApiKey);
toast({
title: "Copied",
description: "AutoGPT Platform API key copied to clipboard",
});
};
const handleRevokeKey = async (keyId: string) => {
try {
await api.revokeAPIKey(keyId);
toast({
title: "Success",
description: "AutoGPT Platform API key revoked successfully",
});
loadAPIKeys();
} catch (error) {
toast({
title: "Error",
description: "Failed to revoke AutoGPT Platform API key",
variant: "destructive",
});
}
};
return (
<Card>
<CardHeader>
<CardTitle>AutoGPT Platform API Keys</CardTitle>
<CardDescription>
Manage your AutoGPT Platform API keys for programmatic access
</CardDescription>
</CardHeader>
<CardContent>
<div className="mb-4 flex justify-end">
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
<DialogTrigger asChild>
<Button>Create Key</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Create New API Key</DialogTitle>
<DialogDescription>
Create a new AutoGPT Platform API key
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="name">Name</Label>
<Input
id="name"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
placeholder="My AutoGPT Platform API Key"
/>
</div>
<div className="grid gap-2">
<Label htmlFor="description">Description (Optional)</Label>
<Input
id="description"
value={newKeyDescription}
onChange={(e) => setNewKeyDescription(e.target.value)}
placeholder="Used for..."
/>
</div>
<div className="grid gap-2">
<Label>Permissions</Label>
{Object.values(APIKeyPermission).map((permission) => (
<div
className="flex items-center space-x-2"
key={permission}
>
<Checkbox
id={permission}
checked={selectedPermissions.includes(permission)}
onCheckedChange={(checked) => {
setSelectedPermissions(
checked
? [...selectedPermissions, permission]
: selectedPermissions.filter(
(p) => p !== permission,
),
);
}}
/>
<Label htmlFor={permission}>{permission}</Label>
</div>
))}
</div>
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => setIsCreateOpen(false)}
>
Cancel
</Button>
<Button onClick={handleCreateKey}>Create</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
<DialogDescription>
Please copy your AutoGPT API key now. You won&apos;t be able
to see it again!
</DialogDescription>
</DialogHeader>
<div className="flex items-center space-x-2">
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
{newApiKey}
</code>
<Button size="icon" variant="outline" onClick={handleCopyKey}>
<LuCopy className="h-4 w-4" />
</Button>
</div>
<DialogFooter>
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
{isLoading ? (
<div className="flex justify-center p-4">
<Loader2 className="h-6 w-6 animate-spin" />
</div>
) : (
apiKeys.length > 0 && (
<Table>
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>API Key</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Last Used</TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{apiKeys.map((key) => (
<TableRow key={key.id}>
<TableCell>{key.name}</TableCell>
<TableCell>
<div className="rounded-md border p-1 px-2 text-xs">
{`${key.prefix}******************${key.postfix}`}
</div>
</TableCell>
<TableCell>
<Badge
variant={
key.status === "ACTIVE" ? "default" : "destructive"
}
className={
key.status === "ACTIVE"
? "border-green-600 bg-green-100 text-green-800"
: "border-red-600 bg-red-100 text-red-800"
}
>
{key.status}
</Badge>
</TableCell>
<TableCell>
{new Date(key.created_at).toLocaleDateString()}
</TableCell>
<TableCell>
{key.last_used_at
? new Date(key.last_used_at).toLocaleDateString()
: "Never"}
</TableCell>
<TableCell>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button variant="ghost" size="sm">
<MoreVertical className="h-4 w-4" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end">
<DropdownMenuItem
className="text-destructive"
onClick={() => handleRevokeKey(key.id)}
>
Revoke
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)
)}
</CardContent>
</Card>
);
}

View File

@ -6,9 +6,11 @@ import {
Carousel,
CarouselContent,
CarouselItem,
CarouselPrevious,
CarouselNext,
CarouselIndicator,
} from "@/components/ui/carousel";
import { useCallback, useState } from "react";
import { IconLeftArrow, IconRightArrow } from "@/components/ui/icons";
import { useRouter } from "next/navigation";
const BACKGROUND_COLORS = [
@ -63,27 +65,24 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
return (
<div className="flex w-full flex-col items-center justify-center">
<div className="w-full">
<h2 className="font-poppins mb-8 text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
<div className="w-[99vw]">
<h2 className="font-poppins mx-auto mb-8 max-w-[1360px] px-4 text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
Featured agents
</h2>
<div>
<div className="w-[99vw] pb-[60px]">
<Carousel
className="mx-auto pb-10"
opts={{
loop: true,
startIndex: currentSlide,
duration: 500,
align: "start",
align: "center",
containScroll: "trimSnaps",
}}
className="w-full overflow-x-hidden"
>
<CarouselContent className="transition-transform duration-500">
<CarouselContent className="ml-[calc(50vw-690px)]">
{featuredAgents.map((agent, index) => (
<CarouselItem
key={index}
className="max-w-[460px] flex-[0_0_auto] pr-8"
className="max-w-[460px] flex-[0_0_auto]"
>
<FeaturedStoreCard
agentName={agent.agent_name}
@ -99,37 +98,13 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
</CarouselItem>
))}
</CarouselContent>
<div className="relative mx-auto w-full max-w-[1360px] pl-4">
<CarouselIndicator />
<CarouselPrevious afterClick={handlePrevSlide} />
<CarouselNext afterClick={handleNextSlide} />
</div>
</Carousel>
</div>
<div className="mt-8 flex w-full items-center justify-between">
<div className="flex h-3 items-center gap-2">
{featuredAgents.map((_, index) => (
<div
key={index}
className={`${
currentSlide === index
? "h-3 w-[52px] rounded-[39px] bg-neutral-800 transition-all duration-500 dark:bg-neutral-200"
: "h-3 w-3 rounded-full bg-neutral-300 transition-all duration-500 dark:bg-neutral-600"
}`}
/>
))}
</div>
<div className="mb-[60px] flex items-center gap-3">
<button
onClick={handlePrevSlide}
className="mb:h-12 mb:w-12 flex h-10 w-10 items-center justify-center rounded-full border border-neutral-400 bg-white dark:border-neutral-600 dark:bg-neutral-800"
>
<IconLeftArrow className="h-8 w-8 text-neutral-800 dark:text-neutral-200" />
</button>
<button
onClick={handleNextSlide}
className="mb:h-12 mb:w-12 flex h-10 w-10 items-center justify-center rounded-full border border-neutral-900 bg-white dark:border-neutral-600 dark:bg-neutral-800"
>
<IconRightArrow className="h-8 w-8 text-neutral-800 dark:text-neutral-200" />
</button>
</div>
</div>
</div>
</div>
);

View File

@ -7,8 +7,15 @@ import SchemaTooltip from "@/components/SchemaTooltip";
import useCredentials from "@/hooks/useCredentials";
import { zodResolver } from "@hookform/resolvers/zod";
import { NotionLogoIcon } from "@radix-ui/react-icons";
import { FaDiscord, FaGithub, FaGoogle, FaMedium, FaKey } from "react-icons/fa";
import { FC, useState } from "react";
import {
FaDiscord,
FaGithub,
FaTwitter,
FaGoogle,
FaMedium,
FaKey,
} from "react-icons/fa";
import { FC, useMemo, useState } from "react";
import {
CredentialsMetaInput,
CredentialsProviderName,
@ -53,6 +60,7 @@ export const providerIcons: Record<
google: FaGoogle,
groq: fallbackIcon,
notion: NotionLogoIcon,
nvidia: fallbackIcon,
discord: FaDiscord,
d_id: fallbackIcon,
google_maps: FaGoogle,
@ -68,6 +76,7 @@ export const providerIcons: Record<
replicate: fallbackIcon,
fal: fallbackIcon,
revid: fallbackIcon,
twitter: FaTwitter,
unreal_speech: fallbackIcon,
exa: fallbackIcon,
hubspot: fallbackIcon,

View File

@ -28,6 +28,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
jina: "Jina",
medium: "Medium",
notion: "Notion",
nvidia: "Nvidia",
ollama: "Ollama",
openai: "OpenAI",
openweathermap: "OpenWeatherMap",
@ -37,6 +38,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
replicate: "Replicate",
fal: "FAL",
revid: "Rev.ID",
twitter: "Twitter",
unreal_speech: "Unreal Speech",
exa: "Exa",
hubspot: "Hubspot",

View File

@ -17,6 +17,7 @@ import {
BlockIOStringSubSchema,
BlockIONumberSubSchema,
BlockIOBooleanSubSchema,
BlockIOSimpleTypeSubSchema,
} from "@/lib/autogpt-server-api/types";
import React, { FC, useCallback, useEffect, useMemo, useState } from "react";
import { Button } from "./ui/button";
@ -40,6 +41,7 @@ import { LocalValuedInput } from "./ui/input";
import NodeHandle from "./NodeHandle";
import { ConnectionData } from "./CustomNode";
import { CredentialsInput } from "./integrations/credentials-input";
import { MultiSelect } from "./ui/multiselect-input";
type NodeObjectInputTreeProps = {
nodeId: string;
@ -101,6 +103,92 @@ const NodeObjectInputTree: FC<NodeObjectInputTreeProps> = ({
export default NodeObjectInputTree;
const NodeImageInput: FC<{
selfKey: string;
schema: BlockIOStringSubSchema;
value?: string;
error?: string;
handleInputChange: NodeObjectInputTreeProps["handleInputChange"];
className?: string;
displayName: string;
}> = ({
selfKey,
schema,
value = "",
error,
handleInputChange,
className,
displayName,
}) => {
const handleFileChange = useCallback(
async (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (!file) return;
// Validate file type
if (!file.type.startsWith("image/")) {
console.error("Please upload an image file");
return;
}
// Convert to base64
const reader = new FileReader();
reader.onload = (e) => {
const base64String = (e.target?.result as string).split(",")[1];
handleInputChange(selfKey, base64String);
};
reader.readAsDataURL(file);
},
[selfKey, handleInputChange],
);
return (
<div className={cn("flex flex-col gap-2", className)}>
<div className="nodrag flex flex-col gap-2">
<div className="flex items-center gap-2">
<Button
variant="outline"
onClick={() =>
document.getElementById(`${selfKey}-upload`)?.click()
}
className="w-full"
>
{value ? "Change Image" : `Upload ${displayName}`}
</Button>
{value && (
<Button
variant="ghost"
className="text-red-500 hover:text-red-700"
onClick={() => handleInputChange(selfKey, "")}
>
<Cross2Icon className="h-4 w-4" />
</Button>
)}
</div>
<input
id={`${selfKey}-upload`}
type="file"
accept="image/*"
onChange={handleFileChange}
className="hidden"
/>
{value && (
<div className="relative mt-2 rounded-md border border-gray-300 p-2 dark:border-gray-600">
<img
src={`data:image/jpeg;base64,${value}`}
alt="Preview"
className="max-h-32 w-full rounded-md object-contain"
/>
</div>
)}
</div>
{error && <span className="error-message">{error}</span>}
</div>
);
};
const NodeDateTimeInput: FC<{
selfKey: string;
schema: BlockIOStringSubSchema;
@ -225,6 +313,8 @@ export const NodeGenericInputField: FC<{
);
}
console.log("propSchema", propSchema);
if ("properties" in propSchema) {
// Render a multi-select for all-boolean sub-schemas with more than 3 properties
if (
@ -290,12 +380,53 @@ export const NodeGenericInputField: FC<{
}
if ("anyOf" in propSchema) {
// Optional oneOf
if (
"oneOf" in propSchema.anyOf[0] &&
propSchema.anyOf[0].oneOf &&
"discriminator" in propSchema.anyOf[0] &&
propSchema.anyOf[0].discriminator
) {
return (
<NodeOneOfDiscriminatorField
nodeId={nodeId}
propKey={propKey}
propSchema={propSchema.anyOf[0]}
currentValue={currentValue}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
className={className}
displayName={displayName}
/>
);
}
// optional items
const types = propSchema.anyOf.map((s) =>
"type" in s ? s.type : undefined,
);
if (types.includes("string") && types.includes("null")) {
// optional string
// optional string and datetime
if (
"format" in propSchema.anyOf[0] &&
propSchema.anyOf[0].format === "date-time"
) {
return (
<NodeDateTimeInput
selfKey={propKey}
schema={propSchema.anyOf[0]}
value={currentValue}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={handleInputChange}
/>
);
}
return (
<NodeStringInput
selfKey={propKey}
@ -356,6 +487,42 @@ export const NodeGenericInputField: FC<{
/>
);
} else if (types.includes("object") && types.includes("null")) {
// rendering optional mutliselect
if (
Object.values(
(propSchema.anyOf[0] as BlockIOObjectSubSchema).properties,
).every(
(subSchema) => "type" in subSchema && subSchema.type == "boolean",
) &&
Object.keys((propSchema.anyOf[0] as BlockIOObjectSubSchema).properties)
.length >= 1
) {
const options = Object.keys(
(propSchema.anyOf[0] as BlockIOObjectSubSchema).properties,
);
const selectedKeys = Object.entries(currentValue || {})
.filter(([_, v]) => v)
.map(([k, _]) => k);
return (
<NodeMultiSelectInput
selfKey={propKey}
schema={propSchema.anyOf[0] as BlockIOObjectSubSchema}
selection={selectedKeys}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={(key, selection) => {
handleInputChange(
key,
Object.fromEntries(
options.map((option) => [option, selection.includes(option)]),
),
);
}}
/>
);
}
return (
<NodeKeyValueInput
nodeId={nodeId}
@ -418,6 +585,19 @@ export const NodeGenericInputField: FC<{
switch (propSchema.type) {
case "string":
if ("image_upload" in propSchema && propSchema.image_upload === true) {
return (
<NodeImageInput
selfKey={propKey}
schema={propSchema}
value={currentValue}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={handleInputChange}
/>
);
}
if ("format" in propSchema && propSchema.format === "date-time") {
return (
<NodeDateTimeInput
@ -523,7 +703,7 @@ const NodeOneOfDiscriminatorField: FC<{
propSchema: any;
currentValue?: any;
errors: { [key: string]: string | undefined };
connections: any;
connections: ConnectionData;
handleInputChange: (key: string, value: any) => void;
handleInputClick: (key: string) => void;
className?: string;
@ -538,7 +718,6 @@ const NodeOneOfDiscriminatorField: FC<{
handleInputChange,
handleInputClick,
className,
displayName,
}) => {
const discriminator = propSchema.discriminator;
@ -554,7 +733,7 @@ const NodeOneOfDiscriminatorField: FC<{
return {
value: variantDiscValue,
schema: variant,
schema: variant as BlockIOSubSchema,
};
})
.filter((v: any) => v.value != null);
@ -585,8 +764,24 @@ const NodeOneOfDiscriminatorField: FC<{
(opt: any) => opt.value === chosenType,
)?.schema;
function getEntryKey(key: string): string {
// use someKey for handle purpose (not childKey)
return `${propKey}_#_${key}`;
}
function isConnected(key: string): boolean {
return connections.some(
(c) => c.targetHandle === getEntryKey(key) && c.target === nodeId,
);
}
return (
<div className={cn("flex flex-col space-y-2", className)}>
<div
className={cn(
"flex min-w-[400px] max-w-[95%] flex-col space-y-4",
className,
)}
>
<Select value={chosenType || ""} onValueChange={handleVariantChange}>
<SelectTrigger className="w-full">
<SelectValue placeholder="Select a type..." />
@ -607,32 +802,36 @@ const NodeOneOfDiscriminatorField: FC<{
if (someKey === "discriminator") {
return null;
}
const childKey = propKey ? `${propKey}.${someKey}` : someKey;
const childKey = propKey ? `${propKey}.${someKey}` : someKey; // for history redo/undo purpose
return (
<div
key={childKey}
className="flex w-full flex-row justify-between space-y-2"
className="mb-4 flex w-full flex-col justify-between space-y-2"
>
<span className="mr-2 mt-3 dark:text-gray-300">
{(childSchema as BlockIOSubSchema).title ||
beautifyString(someKey)}
</span>
<NodeGenericInputField
nodeId={nodeId}
key={propKey}
propKey={childKey}
propSchema={childSchema as BlockIOSubSchema}
currentValue={
currentValue ? currentValue[someKey] : undefined
}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
displayName={
chosenVariantSchema.title || beautifyString(someKey)
}
<NodeHandle
keyName={getEntryKey(someKey)}
schema={childSchema as BlockIOSubSchema}
isConnected={isConnected(getEntryKey(someKey))}
isRequired={false}
side="left"
/>
{!isConnected(someKey) && (
<NodeGenericInputField
nodeId={nodeId}
key={propKey}
propKey={childKey}
propSchema={childSchema as BlockIOSubSchema}
currentValue={
currentValue ? currentValue[someKey] : undefined
}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
displayName={beautifyString(someKey)}
/>
)}
</div>
);
},
@ -827,6 +1026,13 @@ const NodeKeyValueInput: FC<{
);
};
// Checking if schema is type of string
function isStringSubSchema(
schema: BlockIOSimpleTypeSubSchema,
): schema is BlockIOStringSubSchema {
return "type" in schema && schema.type === "string";
}
const NodeArrayInput: FC<{
nodeId: string;
selfKey: string;

View File

@ -1,10 +1,11 @@
// This file has been updated for the Store's "Featured Agent Section". If you want to add Carousel, keep these components in mind: CarouselIndicator, CarouselPrevious, and CarouselNext.
"use client";
import * as React from "react";
import useEmblaCarousel, {
type UseEmblaCarouselType,
} from "embla-carousel-react";
import { ArrowLeft, ArrowRight } from "lucide-react";
import { ChevronLeft, ChevronRight } from "lucide-react";
import { cn } from "@/lib/utils";
import { Button } from "@/components/ui/button";
@ -196,67 +197,137 @@ CarouselItem.displayName = "CarouselItem";
const CarouselPrevious = React.forwardRef<
HTMLButtonElement,
React.ComponentProps<typeof Button>
>(({ className, variant = "outline", size = "icon", ...props }, ref) => {
const { orientation, scrollPrev, canScrollPrev } = useCarousel();
React.ComponentProps<typeof Button> & { afterClick?: () => void }
>(
(
{ className, afterClick, variant = "outline", size = "icon", ...props },
ref,
) => {
const { orientation, scrollPrev, canScrollPrev } = useCarousel();
return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-8 w-8 rounded-full",
orientation === "horizontal"
? "-left-12 top-1/2 -translate-y-1/2"
: "-top-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollPrev}
onClick={scrollPrev}
{...props}
>
<ArrowLeft className="h-4 w-4" />
<span className="sr-only">Previous slide</span>
</Button>
);
});
return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-[52px] w-[52px] rounded-full",
orientation === "horizontal"
? "-bottom-20 right-24 -translate-y-1/2"
: "-top-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollPrev}
onClick={() => {
scrollPrev();
if (afterClick) {
afterClick();
}
}}
{...props}
>
<ChevronLeft className="h-8 w-8" strokeWidth={1.25} />
<span className="sr-only">Previous slide</span>
</Button>
);
},
);
CarouselPrevious.displayName = "CarouselPrevious";
const CarouselNext = React.forwardRef<
HTMLButtonElement,
React.ComponentProps<typeof Button>
>(({ className, variant = "outline", size = "icon", ...props }, ref) => {
const { orientation, scrollNext, canScrollNext } = useCarousel();
React.ComponentProps<typeof Button> & { afterClick?: () => void }
>(
(
{ className, afterClick, variant = "outline", size = "icon", ...props },
ref,
) => {
const { orientation, scrollNext, canScrollNext } = useCarousel();
const handleClick = () => {
scrollNext();
if (afterClick) {
afterClick();
}
};
return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-[52px] w-[52px] rounded-full",
orientation === "horizontal"
? "-bottom-20 right-4 -translate-y-1/2"
: "-bottom-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollNext}
onClick={handleClick}
{...props}
>
<ChevronRight className="h-8 w-8" strokeWidth={1.25} />
<span className="sr-only">Next slide</span>
</Button>
);
},
);
CarouselNext.displayName = "CarouselNext";
const CarouselIndicator = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => {
const { api } = useCarousel();
const [selectedIndex, setSelectedIndex] = React.useState(0);
const [scrollSnaps, setScrollSnaps] = React.useState<number[]>([]);
const scrollTo = React.useCallback(
(index: number) => {
api?.scrollTo(index);
},
[api],
);
React.useEffect(() => {
if (!api) return;
setScrollSnaps(api.scrollSnapList());
api.on("select", () => {
setSelectedIndex(api.selectedScrollSnap());
});
}, [api]);
return (
<Button
<div
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-8 w-8 rounded-full",
orientation === "horizontal"
? "-right-12 top-1/2 -translate-y-1/2"
: "-bottom-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollNext}
onClick={scrollNext}
className={cn("relative top-10 flex h-3 items-center gap-2", className)}
{...props}
>
<ArrowRight className="h-4 w-4" />
<span className="sr-only">Next slide</span>
</Button>
{scrollSnaps.map((_, index) => (
<div
key={index}
onClick={() => scrollTo(index)}
className={cn(
selectedIndex === index
? "h-3 w-[52px] rounded-[39px] bg-neutral-800 transition-all duration-500 dark:bg-neutral-200"
: "h-3 w-3 rounded-full bg-neutral-300 transition-all duration-500 dark:bg-neutral-600",
"cursor-pointer",
)}
/>
))}
</div>
);
});
CarouselNext.displayName = "CarouselNext";
CarouselIndicator.displayName = "CarouselIndicator";
export {
type CarouselApi,
Carousel,
CarouselContent,
CarouselItem,
CarouselIndicator,
CarouselPrevious,
CarouselNext,
};

View File

@ -17,6 +17,9 @@ const isValidVideoUrl = (url: string): boolean => {
};
const isValidImageUrl = (url: string): boolean => {
if (url.startsWith("data:image/")) {
return true;
}
const imageExtensions = /\.(jpeg|jpg|gif|png|svg|webp)$/i;
const cleanedUrl = url.split("?")[0];
return imageExtensions.test(cleanedUrl);
@ -50,19 +53,21 @@ const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
);
};
const ImageRenderer: React.FC<{ imageUrl: string }> = ({ imageUrl }) => (
<div className="w-full p-2">
<picture>
<img
src={imageUrl}
alt="Image"
className="h-auto max-w-full"
width="100%"
height="auto"
/>
</picture>
</div>
);
const ImageRenderer: React.FC<{ imageUrl: string }> = ({ imageUrl }) => {
return (
<div className="w-full p-2">
<picture>
<img
src={imageUrl}
alt="Image"
className="h-auto max-w-full"
width="100%"
height="auto"
/>
</picture>
</div>
);
};
const AudioRenderer: React.FC<{ audioUrl: string }> = ({ audioUrl }) => (
<div className="w-full p-2">
@ -92,6 +97,9 @@ export const ContentRenderer: React.FC<{
truncateLongData?: boolean;
}> = ({ value, truncateLongData }) => {
if (typeof value === "string") {
if (value.startsWith("data:image/")) {
return <ImageRenderer imageUrl={value} />;
}
if (isValidVideoUrl(value)) {
return <VideoRenderer videoUrl={value} />;
} else if (isValidImageUrl(value)) {

View File

@ -29,6 +29,9 @@ import {
StoreReview,
ScheduleCreatable,
Schedule,
APIKeyPermission,
CreateAPIKeyResponse,
APIKey,
} from "./types";
import { createBrowserClient } from "@supabase/ssr";
import getServerSupabase from "../supabase/getServerSupabase";
@ -229,6 +232,36 @@ export default class BackendAPI {
);
}
// API Key related requests
async createAPIKey(
name: string,
permissions: APIKeyPermission[],
description?: string,
): Promise<CreateAPIKeyResponse> {
return this._request("POST", "/api-keys", {
name,
permissions,
description,
});
}
async listAPIKeys(): Promise<APIKey[]> {
return this._get("/api-keys");
}
async revokeAPIKey(keyId: string): Promise<APIKey> {
return this._request("DELETE", `/api-keys/${keyId}`);
}
async updateAPIKeyPermissions(
keyId: string,
permissions: APIKeyPermission[],
): Promise<APIKey> {
return this._request("PUT", `/api-keys/${keyId}/permissions`, {
permissions,
});
}
/**
* @returns `true` if a ping event was received, `false` if provider doesn't support pinging but the webhook exists.
* @throws `Error` if the webhook does not exist.

View File

@ -41,7 +41,7 @@ export type BlockIOSubSchema =
| BlockIOSimpleTypeSubSchema
| BlockIOCombinedTypeSubSchema;
type BlockIOSimpleTypeSubSchema =
export type BlockIOSimpleTypeSubSchema =
| BlockIOObjectSubSchema
| BlockIOCredentialsSubSchema
| BlockIOKVSubSchema
@ -113,6 +113,7 @@ export const PROVIDER_NAMES = {
JINA: "jina",
MEDIUM: "medium",
NOTION: "notion",
NVIDIA: "nvidia",
OLLAMA: "ollama",
OPENAI: "openai",
OPENWEATHERMAP: "openweathermap",
@ -125,6 +126,7 @@ export const PROVIDER_NAMES = {
UNREAL_SPEECH: "unreal_speech",
EXA: "exa",
HUBSPOT: "hubspot",
TWITTER: "twitter",
} as const;
// --8<-- [end:BlockIOCredentialsSubSchema]
@ -512,3 +514,36 @@ export type StoreReviewCreate = {
score: number;
comments?: string;
};
// API Key Types
export enum APIKeyPermission {
EXECUTE_GRAPH = "EXECUTE_GRAPH",
READ_GRAPH = "READ_GRAPH",
EXECUTE_BLOCK = "EXECUTE_BLOCK",
READ_BLOCK = "READ_BLOCK",
}
export enum APIKeyStatus {
ACTIVE = "ACTIVE",
REVOKED = "REVOKED",
SUSPENDED = "SUSPENDED",
}
export interface APIKey {
id: string;
name: string;
prefix: string;
postfix: string;
status: APIKeyStatus;
permissions: APIKeyPermission[];
created_at: string;
last_used_at?: string;
revoked_at?: string;
description?: string;
}
export interface CreateAPIKeyResponse {
api_key: APIKey;
plain_text_key: string;
}

Some files were not shown because too many files have changed in this diff Show More