Merge branch 'dev' into kpczerwinski/open-2263-change-url-from-store-to-marketplace
commit f5e7a1d4a7
@@ -0,0 +1,76 @@
from typing import Annotated, Any, Literal, Optional, TypedDict
from uuid import uuid4

from pydantic import BaseModel, Field, SecretStr, field_serializer


class _BaseCredentials(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    provider: str
    title: Optional[str]

    @field_serializer("*")
    def dump_secret_strings(value: Any, _info):
        if isinstance(value, SecretStr):
            return value.get_secret_value()
        return value


class OAuth2Credentials(_BaseCredentials):
    type: Literal["oauth2"] = "oauth2"
    username: Optional[str]
    """Username of the third-party service user that these credentials belong to"""
    access_token: SecretStr
    access_token_expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the access token expires (if at all)"""
    refresh_token: Optional[SecretStr]
    refresh_token_expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
    scopes: list[str]
    metadata: dict[str, Any] = Field(default_factory=dict)

    def bearer(self) -> str:
        return f"Bearer {self.access_token.get_secret_value()}"


class APIKeyCredentials(_BaseCredentials):
    type: Literal["api_key"] = "api_key"
    api_key: SecretStr
    expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the API key expires (if at all)"""

    def bearer(self) -> str:
        return f"Bearer {self.api_key.get_secret_value()}"


Credentials = Annotated[
    OAuth2Credentials | APIKeyCredentials,
    Field(discriminator="type"),
]


CredentialsType = Literal["api_key", "oauth2"]

class OAuthState(BaseModel):
    token: str
    provider: str
    expires_at: int
    """Unix timestamp (seconds) indicating when this OAuth state expires"""
    code_verifier: Optional[str] = None
    scopes: list[str]


class UserMetadata(BaseModel):
    integration_credentials: list[Credentials] = Field(default_factory=list)
    integration_oauth_states: list[OAuthState] = Field(default_factory=list)


class UserMetadataRaw(TypedDict, total=False):
    integration_credentials: list[dict]
    integration_oauth_states: list[dict]


class UserIntegrations(BaseModel):
    credentials: list[Credentials] = Field(default_factory=list)
    oauth_states: list[OAuthState] = Field(default_factory=list)
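A minimal usage sketch of the discriminated union above. It assumes these models are exposed from `backend.data.model` (as the imports elsewhere in this commit suggest); the concrete dict values are illustrative only.

```python
from pydantic import TypeAdapter

from backend.data.model import APIKeyCredentials, Credentials

# "type" is the discriminator, so a raw dict parses into the matching class.
raw = {
    "provider": "nvidia",
    "type": "api_key",
    "title": "My key",
    "api_key": "secret-value",
    "expires_at": None,
}
creds = TypeAdapter(Credentials).validate_python(raw)
assert isinstance(creds, APIKeyCredentials)

# field_serializer("*") unwraps SecretStr on dump, so the secret round-trips
# as a plain string instead of "**********".
assert creds.model_dump()["api_key"] == "secret-value"

# bearer() builds an Authorization header value for either credentials variant.
print(creds.bearer())  # -> "Bearer secret-value"
```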
@@ -58,6 +58,21 @@ GITHUB_CLIENT_SECRET=
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=

# Twitter (X) OAuth 2.0 with PKCE Configuration
# 1. Create a Twitter Developer Account:
#    - Visit https://developer.x.com/en and sign up
# 2. Set up your application:
#    - Navigate to Developer Portal > Projects > Create Project
#    - Add a new app to your project
# 3. Configure app settings:
#    - App Permissions: Read + Write + Direct Messages
#    - App Type: Web App, Automated App or Bot
#    - OAuth 2.0 Callback URL: http://localhost:3000/auth/integrations/oauth_callback
#    - Save your Client ID and Client Secret below
TWITTER_CLIENT_ID=
TWITTER_CLIENT_SECRET=


## ===== OPTIONAL API KEYS ===== ##

# LLM
@@ -106,6 +121,18 @@ REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=

# Fal
FAL_API_KEY=

# Exa
EXA_API_KEY=

# E2B
E2B_API_KEY=

# Nvidia
NVIDIA_API_KEY=

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false
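A brief sketch of how the new Twitter variables are consumed on the backend. It mirrors the `Secrets` check that appears in `backend/blocks/twitter/_auth.py` later in this commit; treating `Secrets()` as the loader for these environment values is the only assumption here.

```python
from backend.util.settings import Secrets

secrets = Secrets()  # picks up TWITTER_CLIENT_ID / TWITTER_CLIENT_SECRET from the environment

# Same condition as TWITTER_OAUTH_IS_CONFIGURED in the Twitter auth module below.
twitter_oauth_is_configured = bool(
    secrets.twitter_client_id and secrets.twitter_client_secret
)
print("Twitter OAuth configured:", twitter_oauth_is_configured)
```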
@@ -241,7 +241,7 @@ class AgentOutputBlock(Block):
             advanced=True,
         )
         format: str = SchemaField(
-            description="The format string to be used to format the recorded_value.",
+            description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
             default="",
             advanced=True,
         )
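Since the description now points at Jinja2 syntax, here is a minimal illustration of what such a format string does, rendered with the `jinja2` package directly. The variable name `output` is hypothetical; this is a sketch of the templating semantics, not the block's internal code path.

```python
from jinja2 import Template

# Hypothetical format string applied to a recorded value:
format_string = "Agent result: {{ output | upper }}"
print(Template(format_string).render(output="done"))  # -> "Agent result: DONE"
```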
@@ -56,15 +56,24 @@ class SendWebRequestBlock(Block):
         )

     def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        if isinstance(input_data.body, str):
-            input_data.body = json.loads(input_data.body)
+        body = input_data.body
+
+        if input_data.json_format:
+            if isinstance(body, str):
+                try:
+                    # Try to parse as JSON first
+                    body = json.loads(body)
+                except json.JSONDecodeError:
+                    # If it's not valid JSON and just plain text,
+                    # we should send it as plain text instead
+                    input_data.json_format = False

         response = requests.request(
             input_data.method.value,
             input_data.url,
             headers=input_data.headers,
-            json=input_data.body if input_data.json_format else None,
-            data=input_data.body if not input_data.json_format else None,
+            json=body if input_data.json_format else None,
+            data=body if not input_data.json_format else None,
         )
         result = response.json() if input_data.json_format else response.text

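The new body handling can be exercised in isolation. This is a minimal sketch of the same parse-or-fall-back logic using only the standard library, not the block itself.

```python
import json


def prepare_body(body, json_format: bool):
    """Mirror of the block's fallback: try JSON first, otherwise send plain text."""
    if json_format and isinstance(body, str):
        try:
            body = json.loads(body)
        except json.JSONDecodeError:
            json_format = False  # not valid JSON -> send as plain text
    return body, json_format


print(prepare_body('{"a": 1}', True))   # ({'a': 1}, True)
print(prepare_body("just text", True))  # ('just text', False)
```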
@@ -26,8 +26,10 @@ from backend.data.model import (
 )
 from backend.util import json
 from backend.util.settings import BehaveAs, Settings
+from backend.util.text import TextFormatter

 logger = logging.getLogger(__name__)
+fmt = TextFormatter()

 LLMProviderName = Literal[
     ProviderName.ANTHROPIC,

@@ -109,6 +111,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     LLAMA3_1_70B = "llama-3.1-70b-versatile"
     LLAMA3_1_8B = "llama-3.1-8b-instant"
     # Ollama models
+    OLLAMA_LLAMA3_2 = "llama3.2"
     OLLAMA_LLAMA3_8B = "llama3"
     OLLAMA_LLAMA3_405B = "llama3.1:405b"
     OLLAMA_DOLPHIN = "dolphin-mistral:latest"

@@ -163,6 +166,7 @@ MODEL_METADATA = {
     # Limited to 16k during preview
     LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
     LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
+    LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192),
     LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
     LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
     LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
@@ -234,7 +238,9 @@ class AIStructuredResponseGeneratorBlock(Block):
             description="Number of times to retry the LLM call if the response does not match the expected format.",
         )
         prompt_values: dict[str, str] = SchemaField(
-            advanced=False, default={}, description="Values used to fill in the prompt."
+            advanced=False,
+            default={},
+            description="Values used to fill in the prompt. The values can be used in the prompt by putting them in double curly braces, e.g. {{variable_name}}.",
         )
         max_tokens: int | None = SchemaField(
             advanced=True,
@@ -448,8 +454,8 @@ class AIStructuredResponseGeneratorBlock(Block):

         values = input_data.prompt_values
         if values:
-            input_data.prompt = input_data.prompt.format(**values)
-            input_data.sys_prompt = input_data.sys_prompt.format(**values)
+            input_data.prompt = fmt.format_string(input_data.prompt, values)
+            input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)

         if input_data.sys_prompt:
             prompt.append({"role": "system", "content": input_data.sys_prompt})
@@ -576,7 +582,9 @@ class AITextGeneratorBlock(Block):
             description="Number of times to retry the LLM call if the response does not match the expected format.",
         )
         prompt_values: dict[str, str] = SchemaField(
-            advanced=False, default={}, description="Values used to fill in the prompt."
+            advanced=False,
+            default={},
+            description="Values used to fill in the prompt. The values can be used in the prompt by putting them in double curly braces, e.g. {{variable_name}}.",
         )
         ollama_host: str = SchemaField(
             advanced=True,
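A quick sketch of the new prompt-templating path. It assumes `TextFormatter.format_string(template, values)` renders Jinja2-style double-curly placeholders, which is what the updated field descriptions and the call sites in the hunks above rely on.

```python
from backend.util.text import TextFormatter

fmt = TextFormatter()

values = {"topic": "deepfake detection", "audience": "engineers"}
prompt = "Write a short summary about {{topic}} for {{audience}}."

# Replaces the old input_data.prompt.format(**values) call shown above,
# presumably so literal braces in prompts no longer clash with str.format.
print(fmt.format_string(prompt, values))
```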
@@ -0,0 +1,32 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

NvidiaCredentials = APIKeyCredentials
NvidiaCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.NVIDIA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="nvidia",
    api_key=SecretStr("mock-nvidia-api-key"),
    title="Mock Nvidia API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def NvidiaCredentialsField() -> NvidiaCredentialsInput:
    """Creates an Nvidia credentials input on a block."""
    return CredentialsField(description="The Nvidia integration requires an API Key.")
@@ -0,0 +1,90 @@
from backend.blocks.nvidia._auth import (
    NvidiaCredentials,
    NvidiaCredentialsField,
    NvidiaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class NvidiaDeepfakeDetectBlock(Block):
    class Input(BlockSchema):
        credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
        image_base64: str = SchemaField(
            description="Image to analyze for deepfakes", image_upload=True
        )
        return_image: bool = SchemaField(
            description="Whether to return the processed image with markings",
            default=False,
        )

    class Output(BlockSchema):
        status: str = SchemaField(
            description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
            default="",
        )
        image: str = SchemaField(
            description="Processed image with detection markings (if return_image=True)",
            default="",
            image_output=True,
        )
        is_deepfake: float = SchemaField(
            description="Probability that the image is a deepfake (0-1)",
            default=0.0,
        )

    def __init__(self):
        super().__init__(
            id="8c7d0d67-e79c-44f6-92a1-c2600c8aac7f",
            description="Detects potential deepfakes in images using Nvidia's AI API",
            categories={BlockCategory.SAFETY},
            input_schema=NvidiaDeepfakeDetectBlock.Input,
            output_schema=NvidiaDeepfakeDetectBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: NvidiaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://ai.api.nvidia.com/v1/cv/hive/deepfake-image-detection"

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
        }

        image_data = f"data:image/jpeg;base64,{input_data.image_base64}"

        payload = {
            "input": [image_data],
            "return_image": input_data.return_image,
        }

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()

            result = data.get("data", [{}])[0]

            # Get deepfake probability from first bounding box if any
            deepfake_prob = 0.0
            if result.get("bounding_boxes"):
                deepfake_prob = result["bounding_boxes"][0].get("is_deepfake", 0.0)

            yield "status", result.get("status", "ERROR")
            yield "is_deepfake", deepfake_prob

            if input_data.return_image:
                image_data = result.get("image", "")
                output_data = f"data:image/jpeg;base64,{image_data}"
                yield "image", output_data
            else:
                yield "image", ""

        except Exception as e:
            yield "error", str(e)
            yield "status", "ERROR"
            yield "is_deepfake", 0.0
            yield "image", ""
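For reference, a standalone sketch of how the block reads the detection result out of the API response. The response shape below is inferred from the parsing code above, not from Nvidia's documentation, so treat it as illustrative.

```python
# Hypothetical response payload, shaped like what the block's parsing expects:
data = {
    "data": [
        {
            "status": "SUCCESS",
            "image": "<base64-bytes>",
            "bounding_boxes": [{"is_deepfake": 0.87}],
        }
    ]
}

result = data.get("data", [{}])[0]
deepfake_prob = 0.0
if result.get("bounding_boxes"):
    deepfake_prob = result["bounding_boxes"][0].get("is_deepfake", 0.0)

print(result.get("status", "ERROR"), deepfake_prob)  # SUCCESS 0.87
```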
@@ -141,10 +141,10 @@ class ExtractTextInformationBlock(Block):
 class FillTextTemplateBlock(Block):
     class Input(BlockSchema):
         values: dict[str, Any] = SchemaField(
-            description="Values (dict) to be used in format"
+            description="Values (dict) to be used in format. These values can be used by putting them in double curly braces in the format template, e.g. {{value_name}}.",
         )
         format: str = SchemaField(
-            description="Template to format the text using `values`"
+            description="Template to format the text using `values`. Use Jinja2 syntax."
         )

     class Output(BlockSchema):
@@ -160,7 +160,7 @@ class FillTextTemplateBlock(Block):
         test_input=[
             {
                 "values": {"name": "Alice", "hello": "Hello", "world": "World!"},
-                "format": "{hello}, {world} {{name}}",
+                "format": "{{hello}}, {{ world }} {{name}}",
             },
             {
                 "values": {"list": ["Hello", " World!"]},
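The updated test input is plain Jinja2. Rendered directly with the `jinja2` package (a sketch of the templating semantics, not the block's own formatter), it produces:

```python
from jinja2 import Template

values = {"name": "Alice", "hello": "Hello", "world": "World!"}
print(Template("{{hello}}, {{ world }} {{name}}").render(**values))
# -> "Hello, World! Alice"
```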
@@ -0,0 +1,60 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    OAuth2Credentials,
    ProviderName,
)
from backend.integrations.oauth.twitter import TwitterOAuthHandler
from backend.util.settings import Secrets

# --8<-- [start:TwitterOAuthIsConfigured]
secrets = Secrets()
TWITTER_OAUTH_IS_CONFIGURED = bool(
    secrets.twitter_client_id and secrets.twitter_client_secret
)
# --8<-- [end:TwitterOAuthIsConfigured]

TwitterCredentials = OAuth2Credentials
TwitterCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.TWITTER], Literal["oauth2"]
]


# Currently, we request all permissions from the Twitter API up front.
# In the future, if we need incremental permissions, we can use these requested scopes.
def TwitterCredentialsField(scopes: list[str]) -> TwitterCredentialsInput:
    """
    Creates a Twitter credentials input on a block.

    Params:
        scopes: The authorization scopes needed for the block to work.
    """
    return CredentialsField(
        # required_scopes=set(scopes),
        required_scopes=set(TwitterOAuthHandler.DEFAULT_SCOPES + scopes),
        description="The Twitter integration requires OAuth2 authentication.",
    )


TEST_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="twitter",
    access_token=SecretStr("mock-twitter-access-token"),
    refresh_token=SecretStr("mock-twitter-refresh-token"),
    access_token_expires_at=1234567890,
    scopes=["tweet.read", "tweet.write", "users.read", "offline.access"],
    title="Mock Twitter OAuth2 Credentials",
    username="mock-twitter-username",
    refresh_token_expires_at=1234567890,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
@@ -0,0 +1,418 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from backend.blocks.twitter._mappers import (
|
||||
get_backend_expansion,
|
||||
get_backend_field,
|
||||
get_backend_list_expansion,
|
||||
get_backend_list_field,
|
||||
get_backend_media_field,
|
||||
get_backend_place_field,
|
||||
get_backend_poll_field,
|
||||
get_backend_space_expansion,
|
||||
get_backend_space_field,
|
||||
get_backend_user_field,
|
||||
)
|
||||
from backend.blocks.twitter._types import ( # DMEventFieldFilter,
|
||||
DMEventExpansionFilter,
|
||||
DMEventTypeFilter,
|
||||
DMMediaFieldFilter,
|
||||
DMTweetFieldFilter,
|
||||
ExpansionFilter,
|
||||
ListExpansionsFilter,
|
||||
ListFieldsFilter,
|
||||
SpaceExpansionsFilter,
|
||||
SpaceFieldsFilter,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetReplySettingsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
|
||||
|
||||
# Common Builder
|
||||
class TweetExpansionsBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_expansions(self, expansions: ExpansionFilter | None):
|
||||
if expansions:
|
||||
filtered_expansions = [
|
||||
name for name, value in expansions.dict().items() if value is True
|
||||
]
|
||||
|
||||
if filtered_expansions:
|
||||
self.params["expansions"] = ",".join(
|
||||
[get_backend_expansion(exp) for exp in filtered_expansions]
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def add_media_fields(self, media_fields: TweetMediaFieldsFilter | None):
|
||||
if media_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in media_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["media.fields"] = ",".join(
|
||||
[get_backend_media_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_place_fields(self, place_fields: TweetPlaceFieldsFilter | None):
|
||||
if place_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in place_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["place.fields"] = ",".join(
|
||||
[get_backend_place_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_poll_fields(self, poll_fields: TweetPollFieldsFilter | None):
|
||||
if poll_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in poll_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["poll.fields"] = ",".join(
|
||||
[get_backend_poll_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
|
||||
if tweet_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in tweet_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["tweet.fields"] = ",".join(
|
||||
[get_backend_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
|
||||
if user_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in user_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["user.fields"] = ",".join(
|
||||
[get_backend_user_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class UserExpansionsBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_expansions(self, expansions: UserExpansionsFilter | None):
|
||||
if expansions:
|
||||
filtered_expansions = [
|
||||
name for name, value in expansions.dict().items() if value is True
|
||||
]
|
||||
if filtered_expansions:
|
||||
self.params["expansions"] = ",".join(filtered_expansions)
|
||||
return self
|
||||
|
||||
def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
|
||||
if tweet_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in tweet_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["tweet.fields"] = ",".join(
|
||||
[get_backend_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
|
||||
if user_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in user_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["user.fields"] = ",".join(
|
||||
[get_backend_user_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class ListExpansionsBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_expansions(self, expansions: ListExpansionsFilter | None):
|
||||
if expansions:
|
||||
filtered_expansions = [
|
||||
name for name, value in expansions.dict().items() if value is True
|
||||
]
|
||||
if filtered_expansions:
|
||||
self.params["expansions"] = ",".join(
|
||||
[get_backend_list_expansion(exp) for exp in filtered_expansions]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_list_fields(self, list_fields: ListFieldsFilter | None):
|
||||
if list_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in list_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["list.fields"] = ",".join(
|
||||
[get_backend_list_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
|
||||
if user_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in user_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["user.fields"] = ",".join(
|
||||
[get_backend_user_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class SpaceExpansionsBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_expansions(self, expansions: SpaceExpansionsFilter | None):
|
||||
if expansions:
|
||||
filtered_expansions = [
|
||||
name for name, value in expansions.dict().items() if value is True
|
||||
]
|
||||
if filtered_expansions:
|
||||
self.params["expansions"] = ",".join(
|
||||
[get_backend_space_expansion(exp) for exp in filtered_expansions]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_space_fields(self, space_fields: SpaceFieldsFilter | None):
|
||||
if space_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in space_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["space.fields"] = ",".join(
|
||||
[get_backend_space_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
|
||||
if user_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in user_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["user.fields"] = ",".join(
|
||||
[get_backend_user_field(field) for field in filtered_fields]
|
||||
)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class TweetDurationBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_start_time(self, start_time: datetime | None):
|
||||
if start_time:
|
||||
self.params["start_time"] = start_time
|
||||
return self
|
||||
|
||||
def add_end_time(self, end_time: datetime | None):
|
||||
if end_time:
|
||||
self.params["end_time"] = end_time
|
||||
return self
|
||||
|
||||
def add_since_id(self, since_id: str | None):
|
||||
if since_id:
|
||||
self.params["since_id"] = since_id
|
||||
return self
|
||||
|
||||
def add_until_id(self, until_id: str | None):
|
||||
if until_id:
|
||||
self.params["until_id"] = until_id
|
||||
return self
|
||||
|
||||
def add_sort_order(self, sort_order: str | None):
|
||||
if sort_order:
|
||||
self.params["sort_order"] = sort_order
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class DMExpansionsBuilder:
|
||||
def __init__(self, param: Dict[str, Any]):
|
||||
self.params: Dict[str, Any] = param
|
||||
|
||||
def add_expansions(self, expansions: DMEventExpansionFilter):
|
||||
if expansions:
|
||||
filtered_expansions = [
|
||||
name for name, value in expansions.dict().items() if value is True
|
||||
]
|
||||
if filtered_expansions:
|
||||
self.params["expansions"] = ",".join(filtered_expansions)
|
||||
return self
|
||||
|
||||
def add_event_types(self, event_types: DMEventTypeFilter):
|
||||
if event_types:
|
||||
filtered_types = [
|
||||
name for name, value in event_types.dict().items() if value is True
|
||||
]
|
||||
if filtered_types:
|
||||
self.params["event_types"] = ",".join(filtered_types)
|
||||
return self
|
||||
|
||||
def add_media_fields(self, media_fields: DMMediaFieldFilter):
|
||||
if media_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in media_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["media.fields"] = ",".join(filtered_fields)
|
||||
return self
|
||||
|
||||
def add_tweet_fields(self, tweet_fields: DMTweetFieldFilter):
|
||||
if tweet_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in tweet_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["tweet.fields"] = ",".join(filtered_fields)
|
||||
return self
|
||||
|
||||
def add_user_fields(self, user_fields: TweetUserFieldsFilter):
|
||||
if user_fields:
|
||||
filtered_fields = [
|
||||
name for name, value in user_fields.dict().items() if value is True
|
||||
]
|
||||
if filtered_fields:
|
||||
self.params["user.fields"] = ",".join(filtered_fields)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
# Specific Builders
|
||||
class TweetSearchBuilder:
|
||||
def __init__(self):
|
||||
self.params: Dict[str, Any] = {"user_auth": False}
|
||||
|
||||
def add_query(self, query: str):
|
||||
if query:
|
||||
self.params["query"] = query
|
||||
return self
|
||||
|
||||
def add_pagination(self, max_results: int, pagination: str | None):
|
||||
if max_results:
|
||||
self.params["max_results"] = max_results
|
||||
if pagination:
|
||||
self.params["pagination_token"] = pagination
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class TweetPostBuilder:
|
||||
def __init__(self):
|
||||
self.params: Dict[str, Any] = {"user_auth": False}
|
||||
|
||||
def add_text(self, text: str | None):
|
||||
if text:
|
||||
self.params["text"] = text
|
||||
return self
|
||||
|
||||
def add_media(self, media_ids: list, tagged_user_ids: list):
|
||||
if media_ids:
|
||||
self.params["media_ids"] = media_ids
|
||||
if tagged_user_ids:
|
||||
self.params["media_tagged_user_ids"] = tagged_user_ids
|
||||
return self
|
||||
|
||||
def add_deep_link(self, link: str):
|
||||
if link:
|
||||
self.params["direct_message_deep_link"] = link
|
||||
return self
|
||||
|
||||
def add_super_followers(self, for_super_followers: bool):
|
||||
if for_super_followers:
|
||||
self.params["for_super_followers_only"] = for_super_followers
|
||||
return self
|
||||
|
||||
def add_place(self, place_id: str):
|
||||
if place_id:
|
||||
self.params["place_id"] = place_id
|
||||
return self
|
||||
|
||||
def add_poll_options(self, poll_options: list):
|
||||
if poll_options:
|
||||
self.params["poll_options"] = poll_options
|
||||
return self
|
||||
|
||||
def add_poll_duration(self, poll_duration_minutes: int):
|
||||
if poll_duration_minutes:
|
||||
self.params["poll_duration_minutes"] = poll_duration_minutes
|
||||
return self
|
||||
|
||||
def add_quote(self, quote_id: str):
|
||||
if quote_id:
|
||||
self.params["quote_tweet_id"] = quote_id
|
||||
return self
|
||||
|
||||
def add_reply_settings(
|
||||
self,
|
||||
exclude_user_ids: list,
|
||||
reply_to_id: str,
|
||||
settings: TweetReplySettingsFilter,
|
||||
):
|
||||
if exclude_user_ids:
|
||||
self.params["exclude_reply_user_ids"] = exclude_user_ids
|
||||
if reply_to_id:
|
||||
self.params["in_reply_to_tweet_id"] = reply_to_id
|
||||
if settings.All_Users:
|
||||
self.params["reply_settings"] = None
|
||||
elif settings.Following_Users_Only:
|
||||
self.params["reply_settings"] = "following"
|
||||
elif settings.Mentioned_Users_Only:
|
||||
self.params["reply_settings"] = "mentionedUsers"
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
||||
|
||||
|
||||
class TweetGetsBuilder:
|
||||
def __init__(self):
|
||||
self.params: Dict[str, Any] = {"user_auth": False}
|
||||
|
||||
def add_id(self, tweet_id: list[str]):
|
||||
self.params["id"] = tweet_id
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self.params
|
|
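A short usage sketch of the builder classes above, composing search parameters the way the Twitter blocks are expected to. The module paths match the imports used in this diff; the query string and selected fields are illustrative only.

```python
from backend.blocks.twitter._builders import TweetExpansionsBuilder, TweetSearchBuilder
from backend.blocks.twitter._types import ExpansionFilter, TweetFieldsFilter

# Base query parameters (the builder seeds user_auth=False itself).
params = (
    TweetSearchBuilder()
    .add_query("from:AutoGPT -is:retweet")
    .add_pagination(max_results=10, pagination=None)
    .build()
)

# Layer expansions and field selections onto the same dict.
params = (
    TweetExpansionsBuilder(params)
    .add_expansions(ExpansionFilter(Author_User_ID=True))
    .add_tweet_fields(TweetFieldsFilter(Tweet_Text=True, Creation_Time=True))
    .build()
)

print(params)
# e.g. {'user_auth': False, 'query': 'from:AutoGPT -is:retweet', 'max_results': 10,
#       'expansions': 'author_id', 'tweet.fields': 'created_at,text'}
```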
@@ -0,0 +1,234 @@
|
|||
# -------------- Tweets -----------------
|
||||
|
||||
# Tweet Expansions
|
||||
EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Poll_IDs": "attachments.poll_ids",
|
||||
"Media_Keys": "attachments.media_keys",
|
||||
"Author_User_ID": "author_id",
|
||||
"Edit_History_Tweet_IDs": "edit_history_tweet_ids",
|
||||
"Mentioned_Usernames": "entities.mentions.username",
|
||||
"Place_ID": "geo.place_id",
|
||||
"Reply_To_User_ID": "in_reply_to_user_id",
|
||||
"Referenced_Tweet_ID": "referenced_tweets.id",
|
||||
"Referenced_Tweet_Author_ID": "referenced_tweets.id.author_id",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_expansion(frontend_key: str) -> str:
|
||||
result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid expansion key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# TweetReplySettings
|
||||
REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Mentioned_Users_Only": "mentionedUsers",
|
||||
"Following_Users_Only": "following",
|
||||
"All_Users": "all",
|
||||
}
|
||||
|
||||
|
||||
# TweetUserFields
|
||||
def get_backend_reply_setting(frontend_key: str) -> str:
|
||||
result = REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid reply setting key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Account_Creation_Date": "created_at",
|
||||
"User_Bio": "description",
|
||||
"User_Entities": "entities",
|
||||
"User_ID": "id",
|
||||
"User_Location": "location",
|
||||
"Latest_Tweet_ID": "most_recent_tweet_id",
|
||||
"Display_Name": "name",
|
||||
"Pinned_Tweet_ID": "pinned_tweet_id",
|
||||
"Profile_Picture_URL": "profile_image_url",
|
||||
"Is_Protected_Account": "protected",
|
||||
"Account_Statistics": "public_metrics",
|
||||
"Profile_URL": "url",
|
||||
"Username": "username",
|
||||
"Is_Verified": "verified",
|
||||
"Verification_Type": "verified_type",
|
||||
"Content_Withholding_Info": "withheld",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_user_field(frontend_key: str) -> str:
|
||||
result = USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid user field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# TweetFields
|
||||
FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Tweet_Attachments": "attachments",
|
||||
"Author_ID": "author_id",
|
||||
"Context_Annotations": "context_annotations",
|
||||
"Conversation_ID": "conversation_id",
|
||||
"Creation_Time": "created_at",
|
||||
"Edit_Controls": "edit_controls",
|
||||
"Tweet_Entities": "entities",
|
||||
"Geographic_Location": "geo",
|
||||
"Tweet_ID": "id",
|
||||
"Reply_To_User_ID": "in_reply_to_user_id",
|
||||
"Language": "lang",
|
||||
"Public_Metrics": "public_metrics",
|
||||
"Sensitive_Content_Flag": "possibly_sensitive",
|
||||
"Referenced_Tweets": "referenced_tweets",
|
||||
"Reply_Settings": "reply_settings",
|
||||
"Tweet_Source": "source",
|
||||
"Tweet_Text": "text",
|
||||
"Withheld_Content": "withheld",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_field(frontend_key: str) -> str:
|
||||
result = FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# TweetPollFields
|
||||
POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Duration_Minutes": "duration_minutes",
|
||||
"End_DateTime": "end_datetime",
|
||||
"Poll_ID": "id",
|
||||
"Poll_Options": "options",
|
||||
"Voting_Status": "voting_status",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_poll_field(frontend_key: str) -> str:
|
||||
result = POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid poll field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Contained_Within_Places": "contained_within",
|
||||
"Country": "country",
|
||||
"Country_Code": "country_code",
|
||||
"Full_Location_Name": "full_name",
|
||||
"Geographic_Coordinates": "geo",
|
||||
"Place_ID": "id",
|
||||
"Place_Name": "name",
|
||||
"Place_Type": "place_type",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_place_field(frontend_key: str) -> str:
|
||||
result = PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid place field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# TweetMediaFields
|
||||
MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Duration_in_Milliseconds": "duration_ms",
|
||||
"Height": "height",
|
||||
"Media_Key": "media_key",
|
||||
"Preview_Image_URL": "preview_image_url",
|
||||
"Media_Type": "type",
|
||||
"Media_URL": "url",
|
||||
"Width": "width",
|
||||
"Public_Metrics": "public_metrics",
|
||||
"Non_Public_Metrics": "non_public_metrics",
|
||||
"Organic_Metrics": "organic_metrics",
|
||||
"Promoted_Metrics": "promoted_metrics",
|
||||
"Alternative_Text": "alt_text",
|
||||
"Media_Variants": "variants",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_media_field(frontend_key: str) -> str:
|
||||
result = MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid media field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# -------------- Spaces -----------------
|
||||
|
||||
# SpaceExpansions
|
||||
EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE = {
|
||||
"Invited_Users": "invited_user_ids",
|
||||
"Speakers": "speaker_ids",
|
||||
"Creator": "creator_id",
|
||||
"Hosts": "host_ids",
|
||||
"Topics": "topic_ids",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_space_expansion(frontend_key: str) -> str:
|
||||
result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid expansion key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# SpaceFields
|
||||
SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"Space_ID": "id",
|
||||
"Space_State": "state",
|
||||
"Creation_Time": "created_at",
|
||||
"End_Time": "ended_at",
|
||||
"Host_User_IDs": "host_ids",
|
||||
"Language": "lang",
|
||||
"Is_Ticketed": "is_ticketed",
|
||||
"Invited_User_IDs": "invited_user_ids",
|
||||
"Participant_Count": "participant_count",
|
||||
"Subscriber_Count": "subscriber_count",
|
||||
"Scheduled_Start_Time": "scheduled_start",
|
||||
"Speaker_User_IDs": "speaker_ids",
|
||||
"Start_Time": "started_at",
|
||||
"Space_Title": "title",
|
||||
"Topic_IDs": "topic_ids",
|
||||
"Last_Updated_Time": "updated_at",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_space_field(frontend_key: str) -> str:
|
||||
result = SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid space field key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
# -------------- List Expansions -----------------
|
||||
|
||||
# ListExpansions
|
||||
LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {"List_Owner_ID": "owner_id"}
|
||||
|
||||
|
||||
def get_backend_list_expansion(frontend_key: str) -> str:
|
||||
result = LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid list expansion key: {frontend_key}")
|
||||
return result
|
||||
|
||||
|
||||
LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
|
||||
"List_ID": "id",
|
||||
"List_Name": "name",
|
||||
"Creation_Date": "created_at",
|
||||
"Description": "description",
|
||||
"Follower_Count": "follower_count",
|
||||
"Member_Count": "member_count",
|
||||
"Is_Private": "private",
|
||||
"Owner_ID": "owner_id",
|
||||
}
|
||||
|
||||
|
||||
def get_backend_list_field(frontend_key: str) -> str:
|
||||
result = LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
|
||||
if result is None:
|
||||
raise KeyError(f"Invalid list field key: {frontend_key}")
|
||||
return result
|
|
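The mapping helpers above are simple lookups with a loud failure mode. A couple of illustrative calls (module path as imported elsewhere in this diff):

```python
from backend.blocks.twitter._mappers import get_backend_expansion, get_backend_user_field

print(get_backend_expansion("Author_User_ID"))        # -> "author_id"
print(get_backend_user_field("Profile_Picture_URL"))  # -> "profile_image_url"

try:
    get_backend_expansion("Not_A_Key")
except KeyError as e:
    print(e)  # 'Invalid expansion key: Not_A_Key'
```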
@@ -0,0 +1,76 @@
from typing import Any, Dict, List


class BaseSerializer:
    @staticmethod
    def _serialize_value(value: Any) -> Any:
        """Helper method to serialize individual values"""
        if hasattr(value, "data"):
            return value.data
        return value


class IncludesSerializer(BaseSerializer):
    @classmethod
    def serialize(cls, includes: Dict[str, Any]) -> Dict[str, Any]:
        """Serializes the includes dictionary"""
        if not includes:
            return {}

        serialized_includes = {}
        for key, value in includes.items():
            if isinstance(value, list):
                serialized_includes[key] = [
                    cls._serialize_value(item) for item in value
                ]
            else:
                serialized_includes[key] = cls._serialize_value(value)

        return serialized_includes


class ResponseDataSerializer(BaseSerializer):
    @classmethod
    def serialize_dict(cls, item: Dict[str, Any]) -> Dict[str, Any]:
        """Serializes a single dictionary item"""
        serialized_item = {}

        if hasattr(item, "__dict__"):
            items = item.__dict__.items()
        else:
            items = item.items()

        for key, value in items:
            if isinstance(value, list):
                serialized_item[key] = [
                    cls._serialize_value(sub_item) for sub_item in value
                ]
            else:
                serialized_item[key] = cls._serialize_value(value)

        return serialized_item

    @classmethod
    def serialize_list(cls, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Serializes a list of dictionary items"""
        return [cls.serialize_dict(item) for item in data]


class ResponseSerializer:
    @classmethod
    def serialize(cls, response) -> Dict[str, Any]:
        """Main serializer that handles both data and includes"""
        result = {"data": None, "included": {}}

        # Handle response.data
        if response.data:
            if isinstance(response.data, list):
                result["data"] = ResponseDataSerializer.serialize_list(response.data)
            else:
                result["data"] = ResponseDataSerializer.serialize_dict(response.data)

        # Handle includes
        if hasattr(response, "includes") and response.includes:
            result["included"] = IncludesSerializer.serialize(response.includes)

        return result
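A minimal sketch of feeding a tweepy-style response through `ResponseSerializer`. The fake response below only mimics the attributes the serializer reads (`data` holding plain dicts, `includes` holding objects with a `.data` attribute); it is not a real tweepy object.

```python
from types import SimpleNamespace

from backend.blocks.twitter._serializer import ResponseSerializer

fake_response = SimpleNamespace(
    data=[{"id": "1", "text": "hello"}],
    # Included objects expose a `.data` attribute, as tweepy models do.
    includes={"users": [SimpleNamespace(data={"id": "42", "username": "alice"})]},
)

print(ResponseSerializer.serialize(fake_response))
# {'data': [{'id': '1', 'text': 'hello'}],
#  'included': {'users': [{'id': '42', 'username': 'alice'}]}}
```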
@@ -0,0 +1,443 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
# -------------- Tweets -----------------
|
||||
|
||||
|
||||
class TweetReplySettingsFilter(BaseModel):
|
||||
Mentioned_Users_Only: bool = False
|
||||
Following_Users_Only: bool = False
|
||||
All_Users: bool = False
|
||||
|
||||
|
||||
class TweetUserFieldsFilter(BaseModel):
|
||||
Account_Creation_Date: bool = False
|
||||
User_Bio: bool = False
|
||||
User_Entities: bool = False
|
||||
User_ID: bool = False
|
||||
User_Location: bool = False
|
||||
Latest_Tweet_ID: bool = False
|
||||
Display_Name: bool = False
|
||||
Pinned_Tweet_ID: bool = False
|
||||
Profile_Picture_URL: bool = False
|
||||
Is_Protected_Account: bool = False
|
||||
Account_Statistics: bool = False
|
||||
Profile_URL: bool = False
|
||||
Username: bool = False
|
||||
Is_Verified: bool = False
|
||||
Verification_Type: bool = False
|
||||
Content_Withholding_Info: bool = False
|
||||
|
||||
|
||||
class TweetFieldsFilter(BaseModel):
|
||||
Tweet_Attachments: bool = False
|
||||
Author_ID: bool = False
|
||||
Context_Annotations: bool = False
|
||||
Conversation_ID: bool = False
|
||||
Creation_Time: bool = False
|
||||
Edit_Controls: bool = False
|
||||
Tweet_Entities: bool = False
|
||||
Geographic_Location: bool = False
|
||||
Tweet_ID: bool = False
|
||||
Reply_To_User_ID: bool = False
|
||||
Language: bool = False
|
||||
Public_Metrics: bool = False
|
||||
Sensitive_Content_Flag: bool = False
|
||||
Referenced_Tweets: bool = False
|
||||
Reply_Settings: bool = False
|
||||
Tweet_Source: bool = False
|
||||
Tweet_Text: bool = False
|
||||
Withheld_Content: bool = False
|
||||
|
||||
|
||||
class PersonalTweetFieldsFilter(BaseModel):
|
||||
attachments: bool = False
|
||||
author_id: bool = False
|
||||
context_annotations: bool = False
|
||||
conversation_id: bool = False
|
||||
created_at: bool = False
|
||||
edit_controls: bool = False
|
||||
entities: bool = False
|
||||
geo: bool = False
|
||||
id: bool = False
|
||||
in_reply_to_user_id: bool = False
|
||||
lang: bool = False
|
||||
non_public_metrics: bool = False
|
||||
public_metrics: bool = False
|
||||
organic_metrics: bool = False
|
||||
promoted_metrics: bool = False
|
||||
possibly_sensitive: bool = False
|
||||
referenced_tweets: bool = False
|
||||
reply_settings: bool = False
|
||||
source: bool = False
|
||||
text: bool = False
|
||||
withheld: bool = False
|
||||
|
||||
|
||||
class TweetPollFieldsFilter(BaseModel):
|
||||
Duration_Minutes: bool = False
|
||||
End_DateTime: bool = False
|
||||
Poll_ID: bool = False
|
||||
Poll_Options: bool = False
|
||||
Voting_Status: bool = False
|
||||
|
||||
|
||||
class TweetPlaceFieldsFilter(BaseModel):
|
||||
Contained_Within_Places: bool = False
|
||||
Country: bool = False
|
||||
Country_Code: bool = False
|
||||
Full_Location_Name: bool = False
|
||||
Geographic_Coordinates: bool = False
|
||||
Place_ID: bool = False
|
||||
Place_Name: bool = False
|
||||
Place_Type: bool = False
|
||||
|
||||
|
||||
class TweetMediaFieldsFilter(BaseModel):
|
||||
Duration_in_Milliseconds: bool = False
|
||||
Height: bool = False
|
||||
Media_Key: bool = False
|
||||
Preview_Image_URL: bool = False
|
||||
Media_Type: bool = False
|
||||
Media_URL: bool = False
|
||||
Width: bool = False
|
||||
Public_Metrics: bool = False
|
||||
Non_Public_Metrics: bool = False
|
||||
Organic_Metrics: bool = False
|
||||
Promoted_Metrics: bool = False
|
||||
Alternative_Text: bool = False
|
||||
Media_Variants: bool = False
|
||||
|
||||
|
||||
class ExpansionFilter(BaseModel):
|
||||
Poll_IDs: bool = False
|
||||
Media_Keys: bool = False
|
||||
Author_User_ID: bool = False
|
||||
Edit_History_Tweet_IDs: bool = False
|
||||
Mentioned_Usernames: bool = False
|
||||
Place_ID: bool = False
|
||||
Reply_To_User_ID: bool = False
|
||||
Referenced_Tweet_ID: bool = False
|
||||
Referenced_Tweet_Author_ID: bool = False
|
||||
|
||||
|
||||
class TweetExcludesFilter(BaseModel):
|
||||
retweets: bool = False
|
||||
replies: bool = False
|
||||
|
||||
|
||||
# -------------- Users -----------------
|
||||
|
||||
|
||||
class UserExpansionsFilter(BaseModel):
|
||||
pinned_tweet_id: bool = False
|
||||
|
||||
|
||||
# -------------- DMs -----------------
|
||||
|
||||
|
||||
class DMEventFieldFilter(BaseModel):
|
||||
id: bool = False
|
||||
text: bool = False
|
||||
event_type: bool = False
|
||||
created_at: bool = False
|
||||
dm_conversation_id: bool = False
|
||||
sender_id: bool = False
|
||||
participant_ids: bool = False
|
||||
referenced_tweets: bool = False
|
||||
attachments: bool = False
|
||||
|
||||
|
||||
class DMEventTypeFilter(BaseModel):
|
||||
MessageCreate: bool = False
|
||||
ParticipantsJoin: bool = False
|
||||
ParticipantsLeave: bool = False
|
||||
|
||||
|
||||
class DMEventExpansionFilter(BaseModel):
|
||||
attachments_media_keys: bool = False
|
||||
referenced_tweets_id: bool = False
|
||||
sender_id: bool = False
|
||||
participant_ids: bool = False
|
||||
|
||||
|
||||
class DMMediaFieldFilter(BaseModel):
|
||||
duration_ms: bool = False
|
||||
height: bool = False
|
||||
media_key: bool = False
|
||||
preview_image_url: bool = False
|
||||
type: bool = False
|
||||
url: bool = False
|
||||
width: bool = False
|
||||
public_metrics: bool = False
|
||||
alt_text: bool = False
|
||||
variants: bool = False
|
||||
|
||||
|
||||
class DMTweetFieldFilter(BaseModel):
|
||||
attachments: bool = False
|
||||
author_id: bool = False
|
||||
context_annotations: bool = False
|
||||
conversation_id: bool = False
|
||||
created_at: bool = False
|
||||
edit_controls: bool = False
|
||||
entities: bool = False
|
||||
geo: bool = False
|
||||
id: bool = False
|
||||
in_reply_to_user_id: bool = False
|
||||
lang: bool = False
|
||||
public_metrics: bool = False
|
||||
possibly_sensitive: bool = False
|
||||
referenced_tweets: bool = False
|
||||
reply_settings: bool = False
|
||||
source: bool = False
|
||||
text: bool = False
|
||||
withheld: bool = False
|
||||
|
||||
|
||||
# -------------- Spaces -----------------
|
||||
|
||||
|
||||
class SpaceExpansionsFilter(BaseModel):
|
||||
Invited_Users: bool = False
|
||||
Speakers: bool = False
|
||||
Creator: bool = False
|
||||
Hosts: bool = False
|
||||
Topics: bool = False
|
||||
|
||||
|
||||
class SpaceFieldsFilter(BaseModel):
|
||||
Space_ID: bool = False
|
||||
Space_State: bool = False
|
||||
Creation_Time: bool = False
|
||||
End_Time: bool = False
|
||||
Host_User_IDs: bool = False
|
||||
Language: bool = False
|
||||
Is_Ticketed: bool = False
|
||||
Invited_User_IDs: bool = False
|
||||
Participant_Count: bool = False
|
||||
Subscriber_Count: bool = False
|
||||
Scheduled_Start_Time: bool = False
|
||||
Speaker_User_IDs: bool = False
|
||||
Start_Time: bool = False
|
||||
Space_Title: bool = False
|
||||
Topic_IDs: bool = False
|
||||
Last_Updated_Time: bool = False
|
||||
|
||||
|
||||
class SpaceStatesFilter(str, Enum):
|
||||
live = "live"
|
||||
scheduled = "scheduled"
|
||||
all = "all"
|
||||
|
||||
|
||||
# -------------- List Expansions -----------------
|
||||
|
||||
|
||||
class ListExpansionsFilter(BaseModel):
|
||||
List_Owner_ID: bool = False
|
||||
|
||||
|
||||
class ListFieldsFilter(BaseModel):
|
||||
List_ID: bool = False
|
||||
List_Name: bool = False
|
||||
Creation_Date: bool = False
|
||||
Description: bool = False
|
||||
Follower_Count: bool = False
|
||||
Member_Count: bool = False
|
||||
Is_Private: bool = False
|
||||
Owner_ID: bool = False
|
||||
|
||||
|
||||
# --------- [Input Types] -------------
|
||||
class TweetExpansionInputs(BlockSchema):
|
||||
|
||||
expansions: ExpansionFilter | None = SchemaField(
|
||||
description="Choose what extra information you want to get with your tweets. For example:\n- Select 'Media_Keys' to get media details\n- Select 'Author_User_ID' to get user information\n- Select 'Place_ID' to get location details",
|
||||
placeholder="Pick the extra information you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
media_fields: TweetMediaFieldsFilter | None = SchemaField(
|
||||
description="Select what media information you want to see (images, videos, etc). To use this, you must first select 'Media_Keys' in the expansions above.",
|
||||
placeholder="Choose what media details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
place_fields: TweetPlaceFieldsFilter | None = SchemaField(
|
||||
description="Select what location information you want to see (country, coordinates, etc). To use this, you must first select 'Place_ID' in the expansions above.",
|
||||
placeholder="Choose what location details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
poll_fields: TweetPollFieldsFilter | None = SchemaField(
|
||||
description="Select what poll information you want to see (options, voting status, etc). To use this, you must first select 'Poll_IDs' in the expansions above.",
|
||||
placeholder="Choose what poll details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tweet_fields: TweetFieldsFilter | None = SchemaField(
|
||||
description="Select what tweet information you want to see. For referenced tweets (like retweets), select 'Referenced_Tweet_ID' in the expansions above.",
|
||||
placeholder="Choose what tweet details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
user_fields: TweetUserFieldsFilter | None = SchemaField(
|
||||
description="Select what user information you want to see. To use this, you must first select one of these in expansions above:\n- 'Author_User_ID' for tweet authors\n- 'Mentioned_Usernames' for mentioned users\n- 'Reply_To_User_ID' for users being replied to\n- 'Referenced_Tweet_Author_ID' for authors of referenced tweets",
|
||||
placeholder="Choose what user details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class DMEventExpansionInputs(BlockSchema):
|
||||
expansions: DMEventExpansionFilter | None = SchemaField(
|
||||
description="Select expansions to include related data objects in the 'includes' section.",
|
||||
placeholder="Enter expansions",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
event_types: DMEventTypeFilter | None = SchemaField(
|
||||
description="Select DM event types to include in the response.",
|
||||
placeholder="Enter event types",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
media_fields: DMMediaFieldFilter | None = SchemaField(
|
||||
description="Select media fields to include in the response (requires expansions=attachments.media_keys).",
|
||||
placeholder="Enter media fields",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tweet_fields: DMTweetFieldFilter | None = SchemaField(
|
||||
description="Select tweet fields to include in the response (requires expansions=referenced_tweets.id).",
|
||||
placeholder="Enter tweet fields",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
user_fields: TweetUserFieldsFilter | None = SchemaField(
|
||||
description="Select user fields to include in the response (requires expansions=sender_id or participant_ids).",
|
||||
placeholder="Enter user fields",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class UserExpansionInputs(BlockSchema):
|
||||
expansions: UserExpansionsFilter | None = SchemaField(
|
||||
description="Choose what extra information you want to get with user data. Currently only 'pinned_tweet_id' is available to see a user's pinned tweet.",
|
||||
placeholder="Select extra user information to include",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
tweet_fields: TweetFieldsFilter | None = SchemaField(
|
||||
description="Select what tweet information you want to see in pinned tweets. This only works if you select 'pinned_tweet_id' in expansions above.",
|
||||
placeholder="Choose what details to see in pinned tweets",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
user_fields: TweetUserFieldsFilter | None = SchemaField(
|
||||
description="Select what user information you want to see, like username, bio, profile picture, etc.",
|
||||
placeholder="Choose what user details you want to see",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class SpaceExpansionInputs(BlockSchema):
|
||||
expansions: SpaceExpansionsFilter | None = SchemaField(
|
||||
description="Choose additional information you want to get with your Twitter Spaces:\n- Select 'Invited_Users' to see who was invited\n- Select 'Speakers' to see who can speak\n- Select 'Creator' to get details about who made the Space\n- Select 'Hosts' to see who's hosting\n- Select 'Topics' to see Space topics",
|
||||
placeholder="Pick what extra information you want to see about the Space",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
space_fields: SpaceFieldsFilter | None = SchemaField(
|
||||
description="Choose what Space details you want to see, such as:\n- Title\n- Start/End times\n- Number of participants\n- Language\n- State (live/scheduled)\n- And more",
|
||||
placeholder="Choose what Space information you want to get",
|
||||
default=SpaceFieldsFilter(Space_Title=True, Host_User_IDs=True),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
user_fields: TweetUserFieldsFilter | None = SchemaField(
|
||||
description="Choose what user information you want to see. This works when you select any of these in expansions above:\n- 'Creator' for Space creator details\n- 'Hosts' for host information\n- 'Speakers' for speaker details\n- 'Invited_Users' for invited user information",
|
||||
placeholder="Pick what details you want to see about the users",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class ListExpansionInputs(BlockSchema):
|
||||
expansions: ListExpansionsFilter | None = SchemaField(
|
||||
description="Choose what extra information you want to get with your Twitter Lists:\n- Select 'List_Owner_ID' to get details about who owns the list\n\nThis will let you see more details about the list owner when you also select user fields below.",
|
||||
placeholder="Pick what extra list information you want to see",
|
||||
default=ListExpansionsFilter(List_Owner_ID=True),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
user_fields: TweetUserFieldsFilter | None = SchemaField(
|
||||
description="Choose what information you want to see about list owners. This only works when you select 'List_Owner_ID' in expansions above.\n\nYou can see things like:\n- Their username\n- Profile picture\n- Account details\n- And more",
|
||||
placeholder="Select what details you want to see about list owners",
|
||||
default=TweetUserFieldsFilter(User_ID=True, Username=True),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
list_fields: ListFieldsFilter | None = SchemaField(
|
||||
description="Choose what information you want to see about the Twitter Lists themselves, such as:\n- List name\n- Description\n- Number of followers\n- Number of members\n- Whether it's private\n- Creation date\n- And more",
|
||||
placeholder="Pick what list details you want to see",
|
||||
default=ListFieldsFilter(Owner_ID=True),
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
|
||||
class TweetTimeWindowInputs(BlockSchema):
|
||||
start_time: datetime | None = SchemaField(
|
||||
description="Start time in YYYY-MM-DDTHH:mm:ssZ format",
|
||||
placeholder="Enter start time",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
end_time: datetime | None = SchemaField(
|
||||
description="End time in YYYY-MM-DDTHH:mm:ssZ format",
|
||||
placeholder="Enter end time",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
since_id: str | None = SchemaField(
|
||||
description="Returns results with Tweet ID greater than this (more recent than), we give priority to since_id over start_time",
|
||||
placeholder="Enter since ID",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
until_id: str | None = SchemaField(
|
||||
description="Returns results with Tweet ID less than this (that is, older than), and used with since_id",
|
||||
placeholder="Enter until ID",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
sort_order: str | None = SchemaField(
|
||||
description="Order of returned tweets (recency or relevancy)",
|
||||
placeholder="Enter sort order",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
|
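The filter models above are bundles of booleans that back the checkbox-style block inputs; the builders keep only the fields set to `True` and the mappers translate them into the comma-separated field names the Twitter API expects. A small illustrative sketch:

```python
from backend.blocks.twitter._types import TweetUserFieldsFilter

selected = TweetUserFieldsFilter(Username=True, Profile_Picture_URL=True)

# The builders iterate the model like this and keep only the True entries.
print([name for name, value in selected.dict().items() if value])
# -> ['Profile_Picture_URL', 'Username']
```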
@@ -0,0 +1,201 @@
|
|||
# TODO: Add new type support
|
||||
|
||||
# from typing import cast
|
||||
# import tweepy
|
||||
# from tweepy.client import Response
|
||||
|
||||
# from backend.blocks.twitter._serializer import IncludesSerializer, ResponseDataSerializer
|
||||
# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
# from backend.data.model import SchemaField
|
||||
# from backend.blocks.twitter._builders import DMExpansionsBuilder
|
||||
# from backend.blocks.twitter._types import DMEventExpansion, DMEventExpansionInputs, DMEventType, DMMediaField, DMTweetField, TweetUserFields
|
||||
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
# from backend.blocks.twitter._auth import (
|
||||
# TEST_CREDENTIALS,
|
||||
# TEST_CREDENTIALS_INPUT,
|
||||
# TwitterCredentials,
|
||||
# TwitterCredentialsField,
|
||||
# TwitterCredentialsInput,
|
||||
# )
|
||||
|
||||
# Require Pro or Enterprise plan [Manual Testing Required]
|
||||
# class TwitterGetDMEventsBlock(Block):
|
||||
# """
|
||||
# Gets a list of Direct Message events for the authenticated user
|
||||
# """
|
||||
|
||||
# class Input(DMEventExpansionInputs):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["dm.read", "offline.access", "user.read", "tweet.read"]
|
||||
# )
|
||||
|
||||
# dm_conversation_id: str = SchemaField(
|
||||
# description="The ID of the Direct Message conversation",
|
||||
# placeholder="Enter conversation ID",
|
||||
# required=True
|
||||
# )
|
||||
|
||||
# max_results: int = SchemaField(
|
||||
# description="Maximum number of results to return (1-100)",
|
||||
# placeholder="Enter max results",
|
||||
# advanced=True,
|
||||
# default=10,
|
||||
# )
|
||||
|
||||
# pagination_token: str = SchemaField(
|
||||
# description="Token for pagination",
|
||||
# placeholder="Enter pagination token",
|
||||
# advanced=True,
|
||||
# default=""
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# # Common outputs
|
||||
# event_ids: list[str] = SchemaField(description="DM Event IDs")
|
||||
# event_texts: list[str] = SchemaField(description="DM Event text contents")
|
||||
# event_types: list[str] = SchemaField(description="Types of DM events")
|
||||
# next_token: str = SchemaField(description="Token for next page of results")
|
||||
|
||||
# # Complete outputs
|
||||
# data: list[dict] = SchemaField(description="Complete DM events data")
|
||||
# included: dict = SchemaField(description="Additional data requested via expansions")
|
||||
# meta: dict = SchemaField(description="Metadata about the response")
|
||||
# error: str = SchemaField(description="Error message if request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="dc37a6d4-a62e-11ef-a3a5-03061375737b",
|
||||
# description="This block retrieves Direct Message events for the authenticated user.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterGetDMEventsBlock.Input,
|
||||
# output_schema=TwitterGetDMEventsBlock.Output,
|
||||
# test_input={
|
||||
# "dm_conversation_id": "1234567890",
|
||||
# "max_results": 10,
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "expansions": [],
|
||||
# "event_types": [],
|
||||
# "media_fields": [],
|
||||
# "tweet_fields": [],
|
||||
# "user_fields": []
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("event_ids", ["1346889436626259968"]),
|
||||
# ("event_texts", ["Hello just you..."]),
|
||||
# ("event_types", ["MessageCreate"]),
|
||||
# ("next_token", None),
|
||||
# ("data", [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}]),
|
||||
# ("included", {}),
|
||||
# ("meta", {}),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "get_dm_events": lambda *args, **kwargs: (
|
||||
# [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}],
|
||||
# {},
|
||||
# {},
|
||||
# ["1346889436626259968"],
|
||||
# ["Hello just you..."],
|
||||
# ["MessageCreate"],
|
||||
# None
|
||||
# )
|
||||
# }
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def get_dm_events(
|
||||
# credentials: TwitterCredentials,
|
||||
# dm_conversation_id: str,
|
||||
# max_results: int,
|
||||
# pagination_token: str,
|
||||
# expansions: list[DMEventExpansion],
|
||||
# event_types: list[DMEventType],
|
||||
# media_fields: list[DMMediaField],
|
||||
# tweet_fields: list[DMTweetField],
|
||||
# user_fields: list[TweetUserFields]
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value()
|
||||
# )
|
||||
|
||||
# params = {
|
||||
# "dm_conversation_id": dm_conversation_id,
|
||||
# "max_results": max_results,
|
||||
# "pagination_token": None if pagination_token == "" else pagination_token,
|
||||
# "user_auth": False
|
||||
# }
|
||||
|
||||
# params = (DMExpansionsBuilder(params)
|
||||
# .add_expansions(expansions)
|
||||
# .add_event_types(event_types)
|
||||
# .add_media_fields(media_fields)
|
||||
# .add_tweet_fields(tweet_fields)
|
||||
# .add_user_fields(user_fields)
|
||||
# .build())
|
||||
|
||||
# response = cast(Response, client.get_direct_message_events(**params))
|
||||
|
||||
# meta = {}
|
||||
# event_ids = []
|
||||
# event_texts = []
|
||||
# event_types = []
|
||||
# next_token = None
|
||||
|
||||
# if response.meta:
|
||||
# meta = response.meta
|
||||
# next_token = meta.get("next_token")
|
||||
|
||||
# included = IncludesSerializer.serialize(response.includes)
|
||||
# data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
# if response.data:
|
||||
# event_ids = [str(item.id) for item in response.data]
|
||||
# event_texts = [item.text if hasattr(item, "text") else None for item in response.data]
|
||||
# event_types = [item.event_type for item in response.data]
|
||||
|
||||
# return data, included, meta, event_ids, event_texts, event_types, next_token
|
||||
|
||||
# raise Exception("No DM events found")
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# event_data, included, meta, event_ids, event_texts, event_types, next_token = self.get_dm_events(
|
||||
# credentials,
|
||||
# input_data.dm_conversation_id,
|
||||
# input_data.max_results,
|
||||
# input_data.pagination_token,
|
||||
# input_data.expansions,
|
||||
# input_data.event_types,
|
||||
# input_data.media_fields,
|
||||
# input_data.tweet_fields,
|
||||
# input_data.user_fields
|
||||
# )
|
||||
|
||||
# if event_ids:
|
||||
# yield "event_ids", event_ids
|
||||
# if event_texts:
|
||||
# yield "event_texts", event_texts
|
||||
# if event_types:
|
||||
# yield "event_types", event_types
|
||||
# if next_token:
|
||||
# yield "next_token", next_token
|
||||
# if event_data:
|
||||
# yield "data", event_data
|
||||
# if included:
|
||||
# yield "included", included
|
||||
# if meta:
|
||||
# yield "meta", meta
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,260 @@
|
|||
# Todo : Add new Type support
|
||||
|
||||
# from typing import cast
|
||||
|
||||
# import tweepy
|
||||
# from tweepy.client import Response
|
||||
|
||||
# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
# from backend.data.model import SchemaField
|
||||
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
# from backend.blocks.twitter._auth import (
|
||||
# TEST_CREDENTIALS,
|
||||
# TEST_CREDENTIALS_INPUT,
|
||||
# TwitterCredentials,
|
||||
# TwitterCredentialsField,
|
||||
# TwitterCredentialsInput,
|
||||
# )
|
||||
|
||||
# Pro and Enterprise plan [Manual Testing Required]
|
||||
# class TwitterSendDirectMessageBlock(Block):
|
||||
# """
|
||||
# Sends a direct message to a Twitter user
|
||||
# """
|
||||
|
||||
# class Input(BlockSchema):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["offline.access", "direct_messages.write"]
|
||||
# )
|
||||
|
||||
# participant_id: str = SchemaField(
|
||||
# description="The User ID of the account to send DM to",
|
||||
# placeholder="Enter recipient user ID",
|
||||
# default="",
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# dm_conversation_id: str = SchemaField(
|
||||
# description="The conversation ID to send message to",
|
||||
# placeholder="Enter conversation ID",
|
||||
# default="",
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# text: str = SchemaField(
|
||||
# description="Text of the Direct Message (up to 10,000 characters)",
|
||||
# placeholder="Enter message text",
|
||||
# default="",
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# media_id: str = SchemaField(
|
||||
# description="Media ID to attach to the message",
|
||||
# placeholder="Enter media ID",
|
||||
# default=""
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# dm_event_id: str = SchemaField(description="ID of the sent direct message")
|
||||
# dm_conversation_id_: str = SchemaField(description="ID of the conversation")
|
||||
# error: str = SchemaField(description="Error message if sending failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="f32f2786-a62e-11ef-a93d-a3ef199dde7f",
|
||||
# description="This block sends a direct message to a specified Twitter user.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterSendDirectMessageBlock.Input,
|
||||
# output_schema=TwitterSendDirectMessageBlock.Output,
|
||||
# test_input={
|
||||
# "participant_id": "783214",
|
||||
# "dm_conversation_id": "",
|
||||
# "text": "Hello from Twitter API",
|
||||
# "media_id": "",
|
||||
# "credentials": TEST_CREDENTIALS_INPUT
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("dm_event_id", "0987654321"),
|
||||
# ("dm_conversation_id_", "1234567890"),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "send_direct_message": lambda *args, **kwargs: (
|
||||
# "0987654321",
|
||||
# "1234567890"
|
||||
# )
|
||||
# },
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def send_direct_message(
|
||||
# credentials: TwitterCredentials,
|
||||
# participant_id: str,
|
||||
# dm_conversation_id: str,
|
||||
# text: str,
|
||||
# media_id: str
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value()
|
||||
# )
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.create_direct_message(
|
||||
# participant_id=None if participant_id == "" else participant_id,
|
||||
# dm_conversation_id=None if dm_conversation_id == "" else dm_conversation_id,
|
||||
# text=None if text == "" else text,
|
||||
# media_id=None if media_id == "" else media_id,
|
||||
# user_auth=False
|
||||
# )
|
||||
# )
|
||||
|
||||
# if not response.data:
|
||||
# raise Exception("Failed to send direct message")
|
||||
|
||||
# return response.data["dm_event_id"], response.data["dm_conversation_id"]
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
# except Exception as e:
|
||||
# print(f"Unexpected error: {str(e)}")
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# dm_event_id, dm_conversation_id = self.send_direct_message(
|
||||
# credentials,
|
||||
# input_data.participant_id,
|
||||
# input_data.dm_conversation_id,
|
||||
# input_data.text,
|
||||
# input_data.media_id
|
||||
# )
|
||||
# yield "dm_event_id", dm_event_id
|
||||
# yield "dm_conversation_id", dm_conversation_id
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
||||
|
||||
# class TwitterCreateDMConversationBlock(Block):
|
||||
# """
|
||||
# Creates a new group direct message conversation on Twitter
|
||||
# """
|
||||
|
||||
# class Input(BlockSchema):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["offline.access", "dm.write","dm.read","tweet.read","user.read"]
|
||||
# )
|
||||
|
||||
# participant_ids: list[str] = SchemaField(
|
||||
# description="Array of User IDs to create conversation with (max 50)",
|
||||
# placeholder="Enter participant user IDs",
|
||||
# default=[],
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# text: str = SchemaField(
|
||||
# description="Text of the Direct Message (up to 10,000 characters)",
|
||||
# placeholder="Enter message text",
|
||||
# default="",
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# media_id: str = SchemaField(
|
||||
# description="Media ID to attach to the message",
|
||||
# placeholder="Enter media ID",
|
||||
# default="",
|
||||
# advanced=False
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# dm_event_id: str = SchemaField(description="ID of the sent direct message")
|
||||
# dm_conversation_id: str = SchemaField(description="ID of the conversation")
|
||||
# error: str = SchemaField(description="Error message if sending failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="ec11cabc-a62e-11ef-8c0e-3fe37ba2ec92",
|
||||
# description="This block creates a new group DM conversation with specified Twitter users.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterCreateDMConversationBlock.Input,
|
||||
# output_schema=TwitterCreateDMConversationBlock.Output,
|
||||
# test_input={
|
||||
# "participant_ids": ["783214", "2244994945"],
|
||||
# "text": "Hello from Twitter API",
|
||||
# "media_id": "",
|
||||
# "credentials": TEST_CREDENTIALS_INPUT
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("dm_event_id", "0987654321"),
|
||||
# ("dm_conversation_id", "1234567890"),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "create_dm_conversation": lambda *args, **kwargs: (
|
||||
# "0987654321",
|
||||
# "1234567890"
|
||||
# )
|
||||
# },
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def create_dm_conversation(
|
||||
# credentials: TwitterCredentials,
|
||||
# participant_ids: list[str],
|
||||
# text: str,
|
||||
# media_id: str
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value()
|
||||
# )
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.create_direct_message_conversation(
|
||||
# participant_ids=participant_ids,
|
||||
# text=None if text == "" else text,
|
||||
# media_id=None if media_id == "" else media_id,
|
||||
# user_auth=False
|
||||
# )
|
||||
# )
|
||||
|
||||
# if not response.data:
|
||||
# raise Exception("Failed to create DM conversation")
|
||||
|
||||
# return response.data["dm_event_id"], response.data["dm_conversation_id"]
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
# except Exception as e:
|
||||
# print(f"Unexpected error: {str(e)}")
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# dm_event_id, dm_conversation_id = self.create_dm_conversation(
|
||||
# credentials,
|
||||
# input_data.participant_ids,
|
||||
# input_data.text,
|
||||
# input_data.media_id
|
||||
# )
|
||||
# yield "dm_event_id", dm_event_id
|
||||
# yield "dm_conversation_id", dm_conversation_id
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
|
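# Illustrative sketch only (not part of the commented-out blocks above): the
# underlying tweepy call that TwitterSendDirectMessageBlock would wrap once DM
# support is enabled. Requires an API plan with DM access; the token value is a
# placeholder.
#
# import tweepy
#
# client = tweepy.Client(bearer_token="USER_OAUTH2_ACCESS_TOKEN")
# response = client.create_direct_message(
#     participant_id="783214",  # send a 1:1 DM to this user ID
#     text="Hello from the Twitter API",
#     user_auth=False,
# )
# # response.data is expected to carry dm_event_id and dm_conversation_id,
# # which is what the block would yield as outputs.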
@ -0,0 +1,470 @@
|
|||
# from typing import cast
import tweepy

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)

# from backend.blocks.twitter._builders import UserExpansionsBuilder
# from backend.blocks.twitter._types import TweetFields, TweetUserFields, UserExpansionInputs, UserExpansions
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

# from tweepy.client import Response


class TwitterUnfollowListBlock(Block):
    """
    Unfollows a Twitter list for the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["follows.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to unfollow",
            placeholder="Enter list ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the unfollow was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="1f43310a-a62f-11ef-8276-2b06a1bbae1a",
            description="This block unfollows a specified Twitter list for the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnfollowListBlock.Input,
            output_schema=TwitterUnfollowListBlock.Output,
            test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unfollow_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unfollow_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unfollow_list(list_id=list_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unfollow_list(credentials, input_data.list_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterFollowListBlock(Block):
    """
    Follows a Twitter list for the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "list.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to follow",
            placeholder="Enter list ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the follow was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="03d8acf6-a62f-11ef-b17f-b72b04a09e79",
            description="This block follows a specified Twitter list for the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterFollowListBlock.Input,
            output_schema=TwitterFollowListBlock.Output,
            test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"follow_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def follow_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.follow_list(list_id=list_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.follow_list(credentials, input_data.list_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

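# Usage sketch, not part of the original file: drives the static helpers above
# directly, e.g. from a manual test script. Assumes `creds` is a valid
# TwitterCredentials instance whose access token carries the scopes listed in
# the two blocks above; the list ID is an arbitrary example value.
def _follow_then_unfollow_sketch(creds: TwitterCredentials, list_id: str = "84839422"):
    # Both helpers return True on success and raise a TweepyException on API errors.
    if TwitterFollowListBlock.follow_list(creds, list_id):
        print(f"Now following list {list_id}")
    if TwitterUnfollowListBlock.unfollow_list(creds, list_id):
        print(f"Unfollowed list {list_id}")
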
# Enterprise Level [Need to do Manual testing], There is a high possibility that we might get error in this
|
||||
# Needs Type Input in this
|
||||
|
||||
# class TwitterListGetFollowersBlock(Block):
|
||||
# """
|
||||
# Gets followers of a specified Twitter list
|
||||
# """
|
||||
|
||||
# class Input(UserExpansionInputs):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["tweet.read","users.read", "list.read", "offline.access"]
|
||||
# )
|
||||
|
||||
# list_id: str = SchemaField(
|
||||
# description="The ID of the List to get followers for",
|
||||
# placeholder="Enter list ID",
|
||||
# required=True
|
||||
# )
|
||||
|
||||
# max_results: int = SchemaField(
|
||||
# description="Max number of results per page (1-100)",
|
||||
# placeholder="Enter max results",
|
||||
# default=10,
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# pagination_token: str = SchemaField(
|
||||
# description="Token for pagination",
|
||||
# placeholder="Enter pagination token",
|
||||
# default="",
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# user_ids: list[str] = SchemaField(description="List of user IDs of followers")
|
||||
# usernames: list[str] = SchemaField(description="List of usernames of followers")
|
||||
# next_token: str = SchemaField(description="Token for next page of results")
|
||||
# data: list[dict] = SchemaField(description="Complete follower data")
|
||||
# included: dict = SchemaField(description="Additional data requested via expansions")
|
||||
# meta: dict = SchemaField(description="Metadata about the response")
|
||||
# error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="16b289b4-a62f-11ef-95d4-bb29b849eb99",
|
||||
# description="This block retrieves followers of a specified Twitter list.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterListGetFollowersBlock.Input,
|
||||
# output_schema=TwitterListGetFollowersBlock.Output,
|
||||
# test_input={
|
||||
# "list_id": "123456789",
|
||||
# "max_results": 10,
|
||||
# "pagination_token": None,
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "expansions": [],
|
||||
# "tweet_fields": [],
|
||||
# "user_fields": []
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("user_ids", ["2244994945"]),
|
||||
# ("usernames", ["testuser"]),
|
||||
# ("next_token", None),
|
||||
# ("data", {"followers": [{"id": "2244994945", "username": "testuser"}]}),
|
||||
# ("included", {}),
|
||||
# ("meta", {}),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "get_list_followers": lambda *args, **kwargs: ({
|
||||
# "followers": [{"id": "2244994945", "username": "testuser"}]
|
||||
# }, {}, {}, ["2244994945"], ["testuser"], None)
|
||||
# }
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def get_list_followers(
|
||||
# credentials: TwitterCredentials,
|
||||
# list_id: str,
|
||||
# max_results: int,
|
||||
# pagination_token: str,
|
||||
# expansions: list[UserExpansions],
|
||||
# tweet_fields: list[TweetFields],
|
||||
# user_fields: list[TweetUserFields]
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value(),
|
||||
# )
|
||||
|
||||
# params = {
|
||||
# "id": list_id,
|
||||
# "max_results": max_results,
|
||||
# "pagination_token": None if pagination_token == "" else pagination_token,
|
||||
# "user_auth": False
|
||||
# }
|
||||
|
||||
# params = (UserExpansionsBuilder(params)
|
||||
# .add_expansions(expansions)
|
||||
# .add_tweet_fields(tweet_fields)
|
||||
# .add_user_fields(user_fields)
|
||||
# .build())
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.get_list_followers(**params)
|
||||
# )
|
||||
|
||||
# meta = {}
|
||||
# user_ids = []
|
||||
# usernames = []
|
||||
# next_token = None
|
||||
|
||||
# if response.meta:
|
||||
# meta = response.meta
|
||||
# next_token = meta.get("next_token")
|
||||
|
||||
# included = IncludesSerializer.serialize(response.includes)
|
||||
# data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
# if response.data:
|
||||
# user_ids = [str(item.id) for item in response.data]
|
||||
# usernames = [item.username for item in response.data]
|
||||
|
||||
# return data, included, meta, user_ids, usernames, next_token
|
||||
|
||||
# raise Exception("No followers found")
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# followers_data, included, meta, user_ids, usernames, next_token = self.get_list_followers(
|
||||
# credentials,
|
||||
# input_data.list_id,
|
||||
# input_data.max_results,
|
||||
# input_data.pagination_token,
|
||||
# input_data.expansions,
|
||||
# input_data.tweet_fields,
|
||||
# input_data.user_fields
|
||||
# )
|
||||
|
||||
# if user_ids:
|
||||
# yield "user_ids", user_ids
|
||||
# if usernames:
|
||||
# yield "usernames", usernames
|
||||
# if next_token:
|
||||
# yield "next_token", next_token
|
||||
# if followers_data:
|
||||
# yield "data", followers_data
|
||||
# if included:
|
||||
# yield "included", included
|
||||
# if meta:
|
||||
# yield "meta", meta
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
||||
|
||||
# class TwitterGetFollowedListsBlock(Block):
|
||||
# """
|
||||
# Gets lists followed by a specified Twitter user
|
||||
# """
|
||||
|
||||
# class Input(UserExpansionInputs):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["follows.read", "users.read", "list.read", "offline.access"]
|
||||
# )
|
||||
|
||||
# user_id: str = SchemaField(
|
||||
# description="The user ID whose followed Lists to retrieve",
|
||||
# placeholder="Enter user ID",
|
||||
# required=True
|
||||
# )
|
||||
|
||||
# max_results: int = SchemaField(
|
||||
# description="Max number of results per page (1-100)",
|
||||
# placeholder="Enter max results",
|
||||
# default=10,
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# pagination_token: str = SchemaField(
|
||||
# description="Token for pagination",
|
||||
# placeholder="Enter pagination token",
|
||||
# default="",
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# list_ids: list[str] = SchemaField(description="List of list IDs")
|
||||
# list_names: list[str] = SchemaField(description="List of list names")
|
||||
# data: list[dict] = SchemaField(description="Complete list data")
|
||||
# includes: dict = SchemaField(description="Additional data requested via expansions")
|
||||
# meta: dict = SchemaField(description="Metadata about the response")
|
||||
# next_token: str = SchemaField(description="Token for next page of results")
|
||||
# error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="0e18bbfc-a62f-11ef-94fa-1f1e174b809e",
|
||||
# description="This block retrieves all Lists a specified user follows.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterGetFollowedListsBlock.Input,
|
||||
# output_schema=TwitterGetFollowedListsBlock.Output,
|
||||
# test_input={
|
||||
# "user_id": "123456789",
|
||||
# "max_results": 10,
|
||||
# "pagination_token": None,
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "expansions": [],
|
||||
# "tweet_fields": [],
|
||||
# "user_fields": []
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("list_ids", ["12345"]),
|
||||
# ("list_names", ["Test List"]),
|
||||
# ("data", {"followed_lists": [{"id": "12345", "name": "Test List"}]}),
|
||||
# ("includes", {}),
|
||||
# ("meta", {}),
|
||||
# ("next_token", None),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "get_followed_lists": lambda *args, **kwargs: ({
|
||||
# "followed_lists": [{"id": "12345", "name": "Test List"}]
|
||||
# }, {}, {}, ["12345"], ["Test List"], None)
|
||||
# }
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def get_followed_lists(
|
||||
# credentials: TwitterCredentials,
|
||||
# user_id: str,
|
||||
# max_results: int,
|
||||
# pagination_token: str,
|
||||
# expansions: list[UserExpansions],
|
||||
# tweet_fields: list[TweetFields],
|
||||
# user_fields: list[TweetUserFields]
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value(),
|
||||
# )
|
||||
|
||||
# params = {
|
||||
# "id": user_id,
|
||||
# "max_results": max_results,
|
||||
# "pagination_token": None if pagination_token == "" else pagination_token,
|
||||
# "user_auth": False
|
||||
# }
|
||||
|
||||
# params = (UserExpansionsBuilder(params)
|
||||
# .add_expansions(expansions)
|
||||
# .add_tweet_fields(tweet_fields)
|
||||
# .add_user_fields(user_fields)
|
||||
# .build())
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.get_followed_lists(**params)
|
||||
# )
|
||||
|
||||
# meta = {}
|
||||
# list_ids = []
|
||||
# list_names = []
|
||||
# next_token = None
|
||||
|
||||
# if response.meta:
|
||||
# meta = response.meta
|
||||
# next_token = meta.get("next_token")
|
||||
|
||||
# included = IncludesSerializer.serialize(response.includes)
|
||||
# data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
# if response.data:
|
||||
# list_ids = [str(item.id) for item in response.data]
|
||||
# list_names = [item.name for item in response.data]
|
||||
|
||||
# return data, included, meta, list_ids, list_names, next_token
|
||||
|
||||
# raise Exception("No followed lists found")
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# lists_data, included, meta, list_ids, list_names, next_token = self.get_followed_lists(
|
||||
# credentials,
|
||||
# input_data.user_id,
|
||||
# input_data.max_results,
|
||||
# input_data.pagination_token,
|
||||
# input_data.expansions,
|
||||
# input_data.tweet_fields,
|
||||
# input_data.user_fields
|
||||
# )
|
||||
|
||||
# if list_ids:
|
||||
# yield "list_ids", list_ids
|
||||
# if list_names:
|
||||
# yield "list_names", list_names
|
||||
# if next_token:
|
||||
# yield "next_token", next_token
|
||||
# if lists_data:
|
||||
# yield "data", lists_data
|
||||
# if included:
|
||||
# yield "includes", included
|
||||
# if meta:
|
||||
# yield "meta", meta
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,348 @@
|
|||
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import ListExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ListExpansionInputs,
    ListExpansionsFilter,
    ListFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetListBlock(Block):
    """
    Gets information about a Twitter List specified by ID
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to lookup",
            placeholder="Enter list ID",
            required=True,
        )

    class Output(BlockSchema):
        # Common outputs
        id: str = SchemaField(description="ID of the Twitter List")
        name: str = SchemaField(description="Name of the Twitter List")
        owner_id: str = SchemaField(description="ID of the List owner")
        owner_username: str = SchemaField(description="Username of the List owner")

        # Complete outputs
        data: dict = SchemaField(description="Complete list data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about the response")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="34ebc80a-a62f-11ef-9c2a-3fcab6c07079",
            description="This block retrieves information about a specified Twitter List.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetListBlock.Input,
            output_schema=TwitterGetListBlock.Output,
            test_input={
                "list_id": "84839422",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", "84839422"),
                ("name", "Official Twitter Accounts"),
                ("owner_id", "2244994945"),
                ("owner_username", "TwitterAPI"),
                ("data", {"id": "84839422", "name": "Official Twitter Accounts"}),
            ],
            test_mock={
                "get_list": lambda *args, **kwargs: (
                    {"id": "84839422", "name": "Official Twitter Accounts"},
                    {},
                    {},
                    "2244994945",
                    "TwitterAPI",
                )
            },
        )

    @staticmethod
    def get_list(
        credentials: TwitterCredentials,
        list_id: str,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {"id": list_id, "user_auth": False}

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_list(**params))

            meta = {}
            owner_id = ""
            owner_username = ""
            included = {}

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

                if "users" in included:
                    owner_id = str(included["users"][0]["id"])
                    owner_username = included["users"][0]["username"]

            if response.meta:
                meta = response.meta

            if response.data:
                data_dict = ResponseDataSerializer.serialize_dict(response.data)
                return data_dict, included, meta, owner_id, owner_username

            raise Exception("List not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, owner_id, owner_username = self.get_list(
                credentials,
                input_data.list_id,
                input_data.expansions,
                input_data.user_fields,
                input_data.list_fields,
            )

            yield "id", str(list_data["id"])
            yield "name", list_data["name"]
            if owner_id:
                yield "owner_id", owner_id
            if owner_username:
                yield "owner_username", owner_username
            yield "data", {"id": list_data["id"], "name": list_data["name"]}
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetOwnedListsBlock(Block):
    """
    Gets all Lists owned by the specified user
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "list.read", "offline.access"]
        )

        user_id: str = SchemaField(
            description="The user ID whose owned Lists to retrieve",
            placeholder="Enter user ID",
            required=True,
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            placeholder="Enter max results (default 100)",
            advanced=True,
            default=10,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination",
            placeholder="Enter pagination token",
            advanced=True,
            default="",
        )

    class Output(BlockSchema):
        # Common outputs
        list_ids: list[str] = SchemaField(description="List ids of the owned lists")
        list_names: list[str] = SchemaField(description="List names of the owned lists")
        next_token: str = SchemaField(description="Token for next page of results")

        # Complete outputs
        data: list[dict] = SchemaField(description="Complete owned lists data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about the response")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="2b6bdb26-a62f-11ef-a9ce-ff89c2568726",
            description="This block retrieves all Lists owned by a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetOwnedListsBlock.Input,
            output_schema=TwitterGetOwnedListsBlock.Output,
            test_input={
                "user_id": "2244994945",
                "max_results": 10,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("list_ids", ["84839422"]),
                ("list_names", ["Official Twitter Accounts"]),
                ("data", [{"id": "84839422", "name": "Official Twitter Accounts"}]),
            ],
            test_mock={
                "get_owned_lists": lambda *args, **kwargs: (
                    [{"id": "84839422", "name": "Official Twitter Accounts"}],
                    {},
                    {},
                    ["84839422"],
                    ["Official Twitter Accounts"],
                    None,
                )
            },
        )

    @staticmethod
    def get_owned_lists(
        credentials: TwitterCredentials,
        user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_owned_lists(**params))

            meta = {}
            included = {}
            list_ids = []
            list_names = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                list_ids = [
                    str(item.id) for item in response.data if hasattr(item, "id")
                ]
                list_names = [
                    item.name for item in response.data if hasattr(item, "name")
                ]

                return data, included, meta, list_ids, list_names, next_token

            raise Exception("User has no owned lists")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, list_ids, list_names, next_token = (
                self.get_owned_lists(
                    credentials,
                    input_data.user_id,
                    input_data.max_results,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.user_fields,
                    input_data.list_fields,
                )
            )

            if list_ids:
                yield "list_ids", list_ids
            if list_names:
                yield "list_names", list_names
            if next_token:
                yield "next_token", next_token
            if list_data:
                yield "data", list_data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
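
# Pagination sketch, not part of the original file: walks every List owned by a
# user by feeding the returned next_token back in as pagination_token. Assumes
# `creds` is a valid TwitterCredentials instance with the scopes required by
# TwitterGetOwnedListsBlock; note that get_owned_lists raises if a page comes
# back empty, so callers may want to treat that as the end of iteration.
def _iterate_owned_lists_sketch(creds: TwitterCredentials, user_id: str, page_size: int = 10):
    token = ""
    while True:
        data, included, meta, list_ids, list_names, next_token = (
            TwitterGetOwnedListsBlock.get_owned_lists(
                creds, user_id, page_size, token, None, None, None
            )
        )
        # Yield (id, name) pairs for this page.
        yield from zip(list_ids, list_names)
        if not next_token:
            break
        token = next_token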
@ -0,0 +1,527 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import (
|
||||
ListExpansionsBuilder,
|
||||
UserExpansionsBuilder,
|
||||
)
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ListExpansionInputs,
|
||||
ListExpansionsFilter,
|
||||
ListFieldsFilter,
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterRemoveListMemberBlock(Block):
|
||||
"""
|
||||
Removes a member from a Twitter List that the authenticated user owns
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to remove the member from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to remove from the List",
|
||||
placeholder="Enter user ID to remove",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the member was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the removal failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5a3d1320-a62f-11ef-b7ce-a79e7656bcb0",
|
||||
description="This block removes a specified user from a Twitter List owned by the authenticated user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterRemoveListMemberBlock.Input,
|
||||
output_schema=TwitterRemoveListMemberBlock.Output,
|
||||
test_input={
|
||||
"list_id": "123456789",
|
||||
"user_id": "987654321",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"remove_list_member": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
client.remove_list_member(id=list_id, user_id=user_id, user_auth=False)
|
||||
return True
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.remove_list_member(
|
||||
credentials, input_data.list_id, input_data.user_id
|
||||
)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterAddListMemberBlock(Block):
|
||||
"""
|
||||
Adds a member to a Twitter List that the authenticated user owns
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to add the member to",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user to add to the List",
|
||||
placeholder="Enter user ID to add",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the member was successfully added"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the addition failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3ee8284e-a62f-11ef-84e4-8f6e2cbf0ddb",
|
||||
description="This block adds a specified user to a Twitter List owned by the authenticated user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterAddListMemberBlock.Input,
|
||||
output_schema=TwitterAddListMemberBlock.Output,
|
||||
test_input={
|
||||
"list_id": "123456789",
|
||||
"user_id": "987654321",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"add_list_member": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
client.add_list_member(id=list_id, user_id=user_id, user_auth=False)
|
||||
return True
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.add_list_member(
|
||||
credentials, input_data.list_id, input_data.user_id
|
||||
)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetListMembersBlock(Block):
|
||||
"""
|
||||
Gets the members of a specified Twitter List
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to get members from",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results per page (1-100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
ids: list[str] = SchemaField(description="List of member user IDs")
|
||||
usernames: list[str] = SchemaField(description="List of member usernames")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
data: list[dict] = SchemaField(
|
||||
description="Complete user data for list members"
|
||||
)
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4dba046e-a62f-11ef-b69a-87240c84b4c7",
|
||||
description="This block retrieves the members of a specified Twitter List.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetListMembersBlock.Input,
|
||||
output_schema=TwitterGetListMembersBlock.Output,
|
||||
test_input={
|
||||
"list_id": "123456789",
|
||||
"max_results": 2,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["12345", "67890"]),
|
||||
("usernames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "12345", "username": "testuser1"},
|
||||
{"id": "67890", "username": "testuser2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_list_members": lambda *args, **kwargs: (
|
||||
["12345", "67890"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "12345", "username": "testuser1"},
|
||||
{"id": "67890", "username": "testuser2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_list_members(
|
||||
credentials: TwitterCredentials,
|
||||
list_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": list_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_list_members(**params))
|
||||
|
||||
meta = {}
|
||||
included = {}
|
||||
next_token = None
|
||||
user_ids = []
|
||||
usernames = []
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
if response.includes:
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
user_ids = [str(user.id) for user in response.data]
|
||||
usernames = [user.username for user in response.data]
|
||||
return user_ids, usernames, data, included, meta, next_token
|
||||
|
||||
raise Exception("List members not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, usernames, data, included, meta, next_token = self.get_list_members(
|
||||
credentials,
|
||||
input_data.list_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetListMembershipsBlock(Block):
|
||||
"""
|
||||
Gets all Lists that a specified user is a member of
|
||||
"""
|
||||
|
||||
class Input(ListExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.read", "offline.access"]
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="The ID of the user whose List memberships to retrieve",
|
||||
placeholder="Enter user ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results per page (1-100)",
|
||||
placeholder="Enter max results",
|
||||
advanced=True,
|
||||
default=10,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination of results",
|
||||
placeholder="Enter pagination token",
|
||||
advanced=True,
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list_ids: list[str] = SchemaField(description="List of list IDs")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
data: list[dict] = SchemaField(description="List membership data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata about pagination")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="46e6429c-a62f-11ef-81c0-2b55bc7823ba",
|
||||
description="This block retrieves all Lists that a specified user is a member of.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetListMembershipsBlock.Input,
|
||||
output_schema=TwitterGetListMembershipsBlock.Output,
|
||||
test_input={
|
||||
"user_id": "123456789",
|
||||
"max_results": 1,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"list_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("list_ids", ["84839422"]),
|
||||
("data", [{"id": "84839422"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_list_memberships": lambda *args, **kwargs: (
|
||||
[{"id": "84839422"}],
|
||||
{},
|
||||
{},
|
||||
["84839422"],
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_list_memberships(
|
||||
credentials: TwitterCredentials,
|
||||
user_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: ListExpansionsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
list_fields: ListFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
ListExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_user_fields(user_fields)
|
||||
.add_list_fields(list_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_list_memberships(**params))
|
||||
|
||||
meta = {}
|
||||
included = {}
|
||||
next_token = None
|
||||
list_ids = []
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
if response.includes:
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
list_ids = [str(lst.id) for lst in response.data]
|
||||
return data, included, meta, list_ids, next_token
|
||||
|
||||
raise Exception("List memberships not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, meta, list_ids, next_token = self.get_list_memberships(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.user_fields,
|
||||
input_data.list_fields,
|
||||
)
|
||||
|
||||
if list_ids:
|
||||
yield "list_ids", list_ids
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,217 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import TweetExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterGetListTweetsBlock(Block):
|
||||
"""
|
||||
Gets tweets from a specified Twitter list
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List whose Tweets you would like to retrieve",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results per page (1-100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for paginating through results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
|
||||
texts: list[str] = SchemaField(description="List of tweet texts")
|
||||
next_token: str = SchemaField(description="Token for next page of results")
|
||||
|
||||
# Complete outputs
|
||||
data: list[dict] = SchemaField(description="Complete list tweets data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Response metadata including pagination tokens"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6657edb0-a62f-11ef-8c10-0326d832467d",
|
||||
description="This block retrieves tweets from a specified Twitter list.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetListTweetsBlock.Input,
|
||||
output_schema=TwitterGetListTweetsBlock.Output,
|
||||
test_input={
|
||||
"list_id": "84839422",
|
||||
"max_results": 1,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("tweet_ids", ["1234567890"]),
|
||||
("texts", ["Test tweet"]),
|
||||
("data", [{"id": "1234567890", "text": "Test tweet"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_list_tweets": lambda *args, **kwargs: (
|
||||
[{"id": "1234567890", "text": "Test tweet"}],
|
||||
{},
|
||||
{},
|
||||
["1234567890"],
|
||||
["Test tweet"],
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_list_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
list_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": list_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_list_tweets(**params))
|
||||
|
||||
meta = {}
|
||||
included = {}
|
||||
tweet_ids = []
|
||||
texts = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
if response.includes:
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
tweet_ids = [str(item.id) for item in response.data]
|
||||
texts = [item.text for item in response.data]
|
||||
|
||||
return data, included, meta, tweet_ids, texts, next_token
|
||||
|
||||
raise Exception("No tweets found in this list")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
list_data, included, meta, tweet_ids, texts, next_token = (
|
||||
self.get_list_tweets(
|
||||
credentials,
|
||||
input_data.list_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
|
||||
if tweet_ids:
|
||||
yield "tweet_ids", tweet_ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if list_data:
|
||||
yield "data", list_data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,278 @@
|
|||
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterDeleteListBlock(Block):
    """
    Deletes a Twitter List owned by the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to be deleted",
            placeholder="Enter list ID",
            required=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the deletion was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="843c6892-a62f-11ef-a5c8-b71239a78d3b",
            description="This block deletes a specified Twitter List owned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterDeleteListBlock.Input,
            output_schema=TwitterDeleteListBlock.Output,
            test_input={"list_id": "1234567890", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"delete_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def delete_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.delete_list(id=list_id, user_auth=False)
            return True

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.delete_list(credentials, input_data.list_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterUpdateListBlock(Block):
|
||||
"""
|
||||
Updates a Twitter List owned by the authenticated user
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to be updated",
|
||||
placeholder="Enter list ID",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
name: str | None = SchemaField(
|
||||
description="New name for the List",
|
||||
placeholder="Enter list name",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
description: str | None = SchemaField(
|
||||
description="New description for the List",
|
||||
placeholder="Enter list description",
|
||||
default="",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the update was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7d12630a-a62f-11ef-90c9-8f5a996612c3",
|
||||
description="This block updates a specified Twitter List owned by the authenticated user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUpdateListBlock.Input,
|
||||
output_schema=TwitterUpdateListBlock.Output,
|
||||
test_input={
|
||||
"list_id": "1234567890",
|
||||
"name": "Updated List Name",
|
||||
"description": "Updated List Description",
|
||||
"private": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"update_list": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_list(
|
||||
credentials: TwitterCredentials,
|
||||
list_id: str,
|
||||
name: str | None,
|
||||
description: str | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.update_list(
|
||||
id=list_id,
|
||||
name=None if name == "" else name,
|
||||
description=None if description == "" else description,
|
||||
user_auth=False,
|
||||
)
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.update_list(
|
||||
credentials, input_data.list_id, input_data.name, input_data.description
|
||||
)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterCreateListBlock(Block):
|
||||
"""
|
||||
Creates a Twitter List owned by the authenticated user
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "offline.access"]
|
||||
)
|
||||
|
||||
name: str = SchemaField(
|
||||
description="The name of the List to be created",
|
||||
placeholder="Enter list name",
|
||||
advanced=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
description: str | None = SchemaField(
|
||||
description="Description of the List",
|
||||
placeholder="Enter list description",
|
||||
advanced=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
private: bool = SchemaField(
|
||||
description="Whether the List should be private",
|
||||
advanced=False,
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
url: str = SchemaField(description="URL of the created list")
|
||||
list_id: str = SchemaField(description="ID of the created list")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="724148ba-a62f-11ef-89ba-5349b813ef5f",
|
||||
description="This block creates a new Twitter List for the authenticated user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterCreateListBlock.Input,
|
||||
output_schema=TwitterCreateListBlock.Output,
|
||||
test_input={
|
||||
"name": "New List Name",
|
||||
"description": "New List Description",
|
||||
"private": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("list_id", "1234567890"),
|
||||
("url", "https://twitter.com/i/lists/1234567890"),
|
||||
],
|
||||
test_mock={"create_list": lambda *args, **kwargs: ("1234567890")},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_list(
|
||||
credentials: TwitterCredentials,
|
||||
name: str,
|
||||
description: str | None,
|
||||
private: bool,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.create_list(
|
||||
name=None if name == "" else name,
|
||||
description=None if description == "" else description,
|
||||
private=private,
|
||||
user_auth=False,
|
||||
),
|
||||
)
|
||||
|
||||
list_id = str(response.data["id"])
|
||||
|
||||
return list_id
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
list_id = self.create_list(
|
||||
credentials, input_data.name, input_data.description, input_data.private
|
||||
)
|
||||
yield "list_id", list_id
|
||||
yield "url", f"https://twitter.com/i/lists/{list_id}"
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,285 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import ListExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ListExpansionInputs,
|
||||
ListExpansionsFilter,
|
||||
ListFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnpinListBlock(Block):
|
||||
"""
|
||||
Enables the authenticated user to unpin a List.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to unpin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the unpin was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a099c034-a62f-11ef-9622-47d0ceb73555",
|
||||
description="This block allows the authenticated user to unpin a specified List.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnpinListBlock.Input,
|
||||
output_schema=TwitterUnpinListBlock.Output,
|
||||
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"unpin_list": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unpin_list(credentials: TwitterCredentials, list_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unpin_list(list_id=list_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unpin_list(credentials, input_data.list_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterPinListBlock(Block):
|
||||
"""
|
||||
Enables the authenticated user to pin a List.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to pin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the pin was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ec16e48-a62f-11ef-9f35-f3d6de43a802",
|
||||
description="This block allows the authenticated user to pin a specified List.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterPinListBlock.Input,
|
||||
output_schema=TwitterPinListBlock.Output,
|
||||
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"pin_list": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pin_list(credentials: TwitterCredentials, list_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.pin_list(list_id=list_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.pin_list(credentials, input_data.list_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetPinnedListsBlock(Block):
|
||||
"""
|
||||
Returns the Lists pinned by the authenticated user.
|
||||
"""
|
||||
|
||||
class Input(ListExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["lists.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
list_ids: list[str] = SchemaField(description="List IDs of the pinned lists")
|
||||
list_names: list[str] = SchemaField(
|
||||
description="List names of the pinned lists"
|
||||
)
|
||||
|
||||
data: list[dict] = SchemaField(
|
||||
description="Response data containing pinned lists"
|
||||
)
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata about the response")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="97e03aae-a62f-11ef-bc53-5b89cb02888f",
|
||||
description="This block returns the Lists pinned by the authenticated user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetPinnedListsBlock.Input,
|
||||
output_schema=TwitterGetPinnedListsBlock.Output,
|
||||
test_input={
|
||||
"expansions": None,
|
||||
"list_fields": None,
|
||||
"user_fields": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("list_ids", ["84839422"]),
|
||||
("list_names", ["Twitter List"]),
|
||||
("data", [{"id": "84839422", "name": "Twitter List"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_pinned_lists": lambda *args, **kwargs: (
|
||||
[{"id": "84839422", "name": "Twitter List"}],
|
||||
{},
|
||||
{},
|
||||
["84839422"],
|
||||
["Twitter List"],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pinned_lists(
|
||||
credentials: TwitterCredentials,
|
||||
expansions: ListExpansionsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
list_fields: ListFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {"user_auth": False}
|
||||
|
||||
params = (
|
||||
ListExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_user_fields(user_fields)
|
||||
.add_list_fields(list_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_pinned_lists(**params))
|
||||
|
||||
meta = {}
|
||||
included = {}
|
||||
list_ids = []
|
||||
list_names = []
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
|
||||
if response.includes:
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
list_ids = [str(item.id) for item in response.data]
|
||||
list_names = [item.name for item in response.data]
|
||||
return data, included, meta, list_ids, list_names
|
||||
|
||||
raise Exception("Lists not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
list_data, included, meta, list_ids, list_names = self.get_pinned_lists(
|
||||
credentials,
|
||||
input_data.expansions,
|
||||
input_data.user_fields,
|
||||
input_data.list_fields,
|
||||
)
|
||||
|
||||
if list_ids:
|
||||
yield "list_ids", list_ids
|
||||
if list_names:
|
||||
yield "list_names", list_names
|
||||
if list_data:
|
||||
yield "data", list_data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,195 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import SpaceExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
SpaceExpansionInputs,
|
||||
SpaceExpansionsFilter,
|
||||
SpaceFieldsFilter,
|
||||
SpaceStatesFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterSearchSpacesBlock(Block):
|
||||
"""
|
||||
Returns live or scheduled Spaces matching specified search terms [for a week only]
|
||||
"""
|
||||
|
||||
class Input(SpaceExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["spaces.read", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
query: str = SchemaField(
|
||||
description="Search term to find in Space titles",
|
||||
placeholder="Enter search query",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
state: SpaceStatesFilter = SchemaField(
|
||||
description="Type of Spaces to return (live, scheduled, or all)",
|
||||
placeholder="Enter state filter",
|
||||
default=SpaceStatesFilter.all,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs most users will need
|
||||
ids: list[str] = SchemaField(description="List of space IDs")
|
||||
titles: list[str] = SchemaField(description="List of space titles")
|
||||
host_ids: list = SchemaField(description="List of host IDs")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete space data")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="aaefdd48-a62f-11ef-a73c-3f44df63e276",
|
||||
description="This block searches for Twitter Spaces based on specified terms.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterSearchSpacesBlock.Input,
|
||||
output_schema=TwitterSearchSpacesBlock.Output,
|
||||
test_input={
|
||||
"query": "tech",
|
||||
"max_results": 1,
|
||||
"state": "live",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"space_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1234"]),
|
||||
("titles", ["Tech Talk"]),
|
||||
("host_ids", ["5678"]),
|
||||
("data", [{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}]),
|
||||
],
|
||||
test_mock={
|
||||
"search_spaces": lambda *args, **kwargs: (
|
||||
[{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}],
|
||||
{},
|
||||
{},
|
||||
["1234"],
|
||||
["Tech Talk"],
|
||||
["5678"],
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_spaces(
|
||||
credentials: TwitterCredentials,
|
||||
query: str,
|
||||
max_results: int | None,
|
||||
state: SpaceStatesFilter,
|
||||
expansions: SpaceExpansionsFilter | None,
|
||||
space_fields: SpaceFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {"query": query, "max_results": max_results, "state": state.value}
|
||||
|
||||
params = (
|
||||
SpaceExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_space_fields(space_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.search_spaces(**params))
|
||||
|
||||
meta = {}
|
||||
next_token = ""
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
if "next_token" in meta:
|
||||
next_token = meta["next_token"]
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
ids = [str(space["id"]) for space in response.data if "id" in space]
|
||||
titles = [space["title"] for space in data if "title" in space]
|
||||
host_ids = [space["host_ids"] for space in data if "host_ids" in space]
|
||||
|
||||
return data, included, meta, ids, titles, host_ids, next_token
|
||||
|
||||
raise Exception("Spaces not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, meta, ids, titles, host_ids, next_token = (
|
||||
self.search_spaces(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.max_results,
|
||||
input_data.state,
|
||||
input_data.expansions,
|
||||
input_data.space_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if titles:
|
||||
yield "titles", titles
|
||||
if host_ids:
|
||||
yield "host_ids", host_ids
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "includes", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,651 @@
|
|||
from typing import Literal, Union, cast
|
||||
|
||||
import tweepy
|
||||
from pydantic import BaseModel
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import (
|
||||
SpaceExpansionsBuilder,
|
||||
TweetExpansionsBuilder,
|
||||
UserExpansionsBuilder,
|
||||
)
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
SpaceExpansionInputs,
|
||||
SpaceExpansionsFilter,
|
||||
SpaceFieldsFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class SpaceList(BaseModel):
|
||||
discriminator: Literal["space_list"]
|
||||
space_ids: list[str] = SchemaField(
|
||||
description="List of Space IDs to lookup (up to 100)",
|
||||
placeholder="Enter Space IDs",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
class UserList(BaseModel):
|
||||
discriminator: Literal["user_list"]
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup their Spaces (up to 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
class TwitterGetSpacesBlock(Block):
|
||||
"""
|
||||
Gets information about multiple Twitter Spaces specified by Space IDs or creator user IDs
|
||||
"""
|
||||
|
||||
class Input(SpaceExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["spaces.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
identifier: Union[SpaceList, UserList] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Choose whether to lookup spaces by their IDs or by creator user IDs",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
ids: list[str] = SchemaField(description="List of space IDs")
|
||||
titles: list[str] = SchemaField(description="List of space titles")
|
||||
|
||||
# Complete outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete space data")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d75bd7d8-a62f-11ef-b0d8-c7a9496f617f",
|
||||
description="This block retrieves information about multiple Twitter Spaces.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetSpacesBlock.Input,
|
||||
output_schema=TwitterGetSpacesBlock.Output,
|
||||
test_input={
|
||||
"identifier": {
|
||||
"discriminator": "space_list",
|
||||
"space_ids": ["1DXxyRYNejbKM"],
|
||||
},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"space_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1DXxyRYNejbKM"]),
|
||||
("titles", ["Test Space"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{
|
||||
"id": "1DXxyRYNejbKM",
|
||||
"title": "Test Space",
|
||||
"host_id": "1234567",
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_spaces": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"id": "1DXxyRYNejbKM",
|
||||
"title": "Test Space",
|
||||
"host_id": "1234567",
|
||||
}
|
||||
],
|
||||
{},
|
||||
["1DXxyRYNejbKM"],
|
||||
["Test Space"],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_spaces(
|
||||
credentials: TwitterCredentials,
|
||||
identifier: Union[SpaceList, UserList],
|
||||
expansions: SpaceExpansionsFilter | None,
|
||||
space_fields: SpaceFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"ids": (
|
||||
identifier.space_ids if isinstance(identifier, SpaceList) else None
|
||||
),
|
||||
"user_ids": (
|
||||
identifier.user_ids if isinstance(identifier, UserList) else None
|
||||
),
|
||||
}
|
||||
|
||||
params = (
|
||||
SpaceExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_space_fields(space_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_spaces(**params))
|
||||
|
||||
ids = []
|
||||
titles = []
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
ids = [space["id"] for space in data if "id" in space]
|
||||
titles = [space["title"] for space in data if "title" in space]
|
||||
|
||||
return data, included, ids, titles
|
||||
|
||||
raise Exception("No spaces found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, ids, titles = self.get_spaces(
|
||||
credentials,
|
||||
input_data.identifier,
|
||||
input_data.expansions,
|
||||
input_data.space_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if titles:
|
||||
yield "titles", titles
|
||||
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "includes", included
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetSpaceByIdBlock(Block):
|
||||
"""
|
||||
Gets information about a single Twitter Space specified by Space ID
|
||||
"""
|
||||
|
||||
class Input(SpaceExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["spaces.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
id: str = SchemaField(description="Space ID")
|
||||
title: str = SchemaField(description="Space title")
|
||||
host_ids: list[str] = SchemaField(description="Host ID")
|
||||
|
||||
# Complete outputs for advanced use
|
||||
data: dict = SchemaField(description="Complete space data")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c79700de-a62f-11ef-ab20-fb32bf9d5a9d",
|
||||
description="This block retrieves information about a single Twitter Space.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetSpaceByIdBlock.Input,
|
||||
output_schema=TwitterGetSpaceByIdBlock.Output,
|
||||
test_input={
|
||||
"space_id": "1DXxyRYNejbKM",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"space_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "1DXxyRYNejbKM"),
|
||||
("title", "Test Space"),
|
||||
("host_ids", ["1234567"]),
|
||||
(
|
||||
"data",
|
||||
{
|
||||
"id": "1DXxyRYNejbKM",
|
||||
"title": "Test Space",
|
||||
"host_ids": ["1234567"],
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_space": lambda *args, **kwargs: (
|
||||
{
|
||||
"id": "1DXxyRYNejbKM",
|
||||
"title": "Test Space",
|
||||
"host_ids": ["1234567"],
|
||||
},
|
||||
{},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_space(
|
||||
credentials: TwitterCredentials,
|
||||
space_id: str,
|
||||
expansions: SpaceExpansionsFilter | None,
|
||||
space_fields: SpaceFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": space_id,
|
||||
}
|
||||
|
||||
params = (
|
||||
SpaceExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_space_fields(space_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_space(**params))
|
||||
|
||||
includes = {}
|
||||
if response.includes:
|
||||
for key, value in response.includes.items():
|
||||
if isinstance(value, list):
|
||||
includes[key] = [
|
||||
item.data if hasattr(item, "data") else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
includes[key] = value.data if hasattr(value, "data") else value
|
||||
|
||||
data = {}
|
||||
if response.data:
|
||||
for key, value in response.data.items():
|
||||
if isinstance(value, list):
|
||||
data[key] = [
|
||||
item.data if hasattr(item, "data") else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
data[key] = value.data if hasattr(value, "data") else value
|
||||
|
||||
return data, includes
|
||||
|
||||
raise Exception("Space not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
space_data, includes = self.get_space(
|
||||
credentials,
|
||||
input_data.space_id,
|
||||
input_data.expansions,
|
||||
input_data.space_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
# Common outputs
|
||||
if space_data:
|
||||
if "id" in space_data:
|
||||
yield "id", space_data.get("id")
|
||||
|
||||
if "title" in space_data:
|
||||
yield "title", space_data.get("title")
|
||||
|
||||
if "host_ids" in space_data:
|
||||
yield "host_ids", space_data.get("host_ids")
|
||||
|
||||
if space_data:
|
||||
yield "data", space_data
|
||||
if includes:
|
||||
yield "includes", includes
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
# Not tested yet; might have some problems
|
||||
class TwitterGetSpaceBuyersBlock(Block):
|
||||
"""
|
||||
Gets list of users who purchased a ticket to the requested Space
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["spaces.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup buyers for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
buyer_ids: list[str] = SchemaField(description="List of buyer IDs")
|
||||
usernames: list[str] = SchemaField(description="List of buyer usernames")
|
||||
|
||||
# Complete outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete space buyers data")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c1c121a8-a62f-11ef-8b0e-d7b85f96a46f",
|
||||
description="This block retrieves a list of users who purchased tickets to a Twitter Space.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetSpaceBuyersBlock.Input,
|
||||
output_schema=TwitterGetSpaceBuyersBlock.Output,
|
||||
test_input={
|
||||
"space_id": "1DXxyRYNejbKM",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("buyer_ids", ["2244994945"]),
|
||||
("usernames", ["testuser"]),
|
||||
(
|
||||
"data",
|
||||
[{"id": "2244994945", "username": "testuser", "name": "Test User"}],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_space_buyers": lambda *args, **kwargs: (
|
||||
[{"id": "2244994945", "username": "testuser", "name": "Test User"}],
|
||||
{},
|
||||
["2244994945"],
|
||||
["testuser"],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_space_buyers(
|
||||
credentials: TwitterCredentials,
|
||||
space_id: str,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": space_id,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_space_buyers(**params))
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
buyer_ids = [buyer["id"] for buyer in data]
|
||||
usernames = [buyer["username"] for buyer in data]
|
||||
|
||||
return data, included, buyer_ids, usernames
|
||||
|
||||
raise Exception("No buyers found for this Space")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
buyers_data, included, buyer_ids, usernames = self.get_space_buyers(
|
||||
credentials,
|
||||
input_data.space_id,
|
||||
input_data.expansions,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
if buyer_ids:
|
||||
yield "buyer_ids", buyer_ids
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
|
||||
if buyers_data:
|
||||
yield "data", buyers_data
|
||||
if included:
|
||||
yield "includes", included
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetSpaceTweetsBlock(Block):
|
||||
"""
|
||||
Gets list of Tweets shared in the requested Space
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["spaces.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
space_id: str = SchemaField(
|
||||
description="Space ID to lookup tweets for",
|
||||
placeholder="Enter Space ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
|
||||
texts: list[str] = SchemaField(description="List of tweet texts")
|
||||
|
||||
# Complete outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete space tweets data")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Response metadata")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b69731e6-a62f-11ef-b2d4-1bf14dd6aee4",
|
||||
description="This block retrieves tweets shared in a Twitter Space.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetSpaceTweetsBlock.Input,
|
||||
output_schema=TwitterGetSpaceTweetsBlock.Output,
|
||||
test_input={
|
||||
"space_id": "1DXxyRYNejbKM",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("tweet_ids", ["1234567890"]),
|
||||
("texts", ["Test tweet"]),
|
||||
("data", [{"id": "1234567890", "text": "Test tweet"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_space_tweets": lambda *args, **kwargs: (
|
||||
[{"id": "1234567890", "text": "Test tweet"}], # data
|
||||
{},
|
||||
["1234567890"],
|
||||
["Test tweet"],
|
||||
{},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_space_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
space_id: str,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": space_id,
|
||||
}
|
||||
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_space_tweets(**params))
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
tweet_ids = [str(tweet["id"]) for tweet in data]
|
||||
texts = [tweet["text"] for tweet in data]
|
||||
|
||||
meta = response.meta or {}
|
||||
|
||||
return data, included, tweet_ids, texts, meta
|
||||
|
||||
raise Exception("No tweets found for this Space")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
tweets_data, included, tweet_ids, texts, meta = self.get_space_tweets(
|
||||
credentials,
|
||||
input_data.space_id,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
if tweet_ids:
|
||||
yield "tweet_ids", tweet_ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
|
||||
if tweets_data:
|
||||
yield "data", tweets_data
|
||||
if included:
|
||||
yield "includes", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,20 @@
|
|||
import tweepy


def handle_tweepy_exception(e: Exception) -> str:
    if isinstance(e, tweepy.BadRequest):
        return f"Bad Request (400): {str(e)}"
    elif isinstance(e, tweepy.Unauthorized):
        return f"Unauthorized (401): {str(e)}"
    elif isinstance(e, tweepy.Forbidden):
        return f"Forbidden (403): {str(e)}"
    elif isinstance(e, tweepy.NotFound):
        return f"Not Found (404): {str(e)}"
    elif isinstance(e, tweepy.TooManyRequests):
        return f"Too Many Requests (429): {str(e)}"
    elif isinstance(e, tweepy.TwitterServerError):
        return f"Twitter Server Error (5xx): {str(e)}"
    elif isinstance(e, tweepy.TweepyException):
        return f"Tweepy Error: {str(e)}"
    else:
        return f"Unexpected error: {str(e)}"
|
|
@ -0,0 +1,372 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import TweetExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterBookmarkTweetBlock(Block):
|
||||
"""
|
||||
Bookmark a tweet on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "bookmark.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to bookmark",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the bookmark was successful")
|
||||
error: str = SchemaField(description="Error message if the bookmark failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f33d67be-a62f-11ef-a797-ff83ec29ee8e",
|
||||
description="This block bookmarks a tweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterBookmarkTweetBlock.Input,
|
||||
output_schema=TwitterBookmarkTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"bookmark_tweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def bookmark_tweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.bookmark(tweet_id)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.bookmark_tweet(credentials, input_data.tweet_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetBookmarkedTweetsBlock(Block):
|
||||
"""
|
||||
Get All your bookmarked tweets from Twitter
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "bookmark.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs most users will need
|
||||
id: list[str] = SchemaField(description="All Tweet IDs")
|
||||
text: list[str] = SchemaField(description="All Tweet texts")
|
||||
userId: list[str] = SchemaField(description="IDs of the tweet authors")
|
||||
userName: list[str] = SchemaField(description="Usernames of the tweet authors")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ed26783e-a62f-11ef-9a21-c77c57dd8a1f",
|
||||
description="This block retrieves bookmarked tweets from Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetBookmarkedTweetsBlock.Input,
|
||||
output_schema=TwitterGetBookmarkedTweetsBlock.Output,
|
||||
test_input={
|
||||
"max_results": 2,
|
||||
"pagination_token": None,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", ["1234567890"]),
|
||||
("text", ["Test tweet"]),
|
||||
("userId", ["12345"]),
|
||||
("userName", ["testuser"]),
|
||||
("data", [{"id": "1234567890", "text": "Test tweet"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_bookmarked_tweets": lambda *args, **kwargs: (
|
||||
["1234567890"],
|
||||
["Test tweet"],
|
||||
["12345"],
|
||||
["testuser"],
|
||||
[{"id": "1234567890", "text": "Test tweet"}],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_bookmarked_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
}
|
||||
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.get_bookmarks(**params),
|
||||
)
|
||||
|
||||
meta = {}
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in included:
|
||||
for user in included["users"]:
|
||||
user_ids.append(str(user["id"]))
|
||||
user_names.append(user["username"])
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
raise Exception("No bookmarked tweets found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_bookmarked_tweets(
|
||||
credentials,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "id", ids
|
||||
if texts:
|
||||
yield "text", texts
|
||||
if user_ids:
|
||||
yield "userId", user_ids
|
||||
if user_names:
|
||||
yield "userName", user_names
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterRemoveBookmarkTweetBlock(Block):
|
||||
"""
|
||||
Remove a bookmark for a tweet on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "bookmark.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to remove bookmark from",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the bookmark was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the bookmark removal failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e4100684-a62f-11ef-9be9-770cb41a2616",
|
||||
description="This block removes a bookmark from a tweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterRemoveBookmarkTweetBlock.Input,
|
||||
output_schema=TwitterRemoveBookmarkTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"remove_bookmark_tweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_bookmark_tweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.remove_bookmark(tweet_id)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.remove_bookmark_tweet(credentials, input_data.tweet_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,154 @@
|
|||
import tweepy
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterHideReplyBlock(Block):
|
||||
"""
|
||||
Hides a reply of one of your tweets
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet reply to hide",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="07d58b3e-a630-11ef-a030-93701d1a465e",
|
||||
description="This block hides a reply to a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterHideReplyBlock.Input,
|
||||
output_schema=TwitterHideReplyBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"hide_reply": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hide_reply(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.hide_reply(id=tweet_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.hide_reply(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterUnhideReplyBlock(Block):
|
||||
"""
|
||||
Unhides a reply to a tweet
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet reply to unhide",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fcf9e4e4-a62f-11ef-9d85-57d3d06b616a",
|
||||
description="This block unhides a reply to a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnhideReplyBlock.Input,
|
||||
output_schema=TwitterUnhideReplyBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unhide_reply": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unhide_reply(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unhide_reply(id=tweet_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unhide_reply(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,576 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import (
|
||||
TweetExpansionsBuilder,
|
||||
UserExpansionsBuilder,
|
||||
)
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterLikeTweetBlock(Block):
|
||||
"""
|
||||
Likes a tweet
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "like.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to like",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4d0b4c5c-a630-11ef-8e08-1b14c507b347",
|
||||
description="This block likes a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterLikeTweetBlock.Input,
|
||||
output_schema=TwitterLikeTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"like_tweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def like_tweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.like(tweet_id=tweet_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.like_tweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetLikingUsersBlock(Block):
|
||||
"""
|
||||
Gets information about users who liked a one of your tweet
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "like.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to get liking users for",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for getting next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Commonly used outputs
|
||||
id: list[str] = SchemaField(description="All User IDs who liked the tweet")
|
||||
username: list[str] = SchemaField(
|
||||
description="All User usernames who liked the tweet"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="34275000-a630-11ef-b01e-5f00d9077c08",
|
||||
description="This block gets information about users who liked a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetLikingUsersBlock.Input,
|
||||
output_schema=TwitterGetLikingUsersBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"max_results": 1,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", ["1234567890"]),
|
||||
("username", ["testuser"]),
|
||||
("data", [{"id": "1234567890", "username": "testuser"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_liking_users": lambda *args, **kwargs: (
|
||||
["1234567890"],
|
||||
["testuser"],
|
||||
[{"id": "1234567890", "username": "testuser"}],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_liking_users(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": tweet_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_liking_users(**params))
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No liking users found")
|
||||
|
||||
meta = {}
|
||||
user_ids = []
|
||||
usernames = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
user_ids = [str(user.id) for user in response.data]
|
||||
usernames = [user.username for user in response.data]
|
||||
|
||||
return user_ids, usernames, data, included, meta, next_token
|
||||
|
||||
raise Exception("No liking users found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, usernames, data, included, meta, next_token = self.get_liking_users(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "id", ids
|
||||
if usernames:
|
||||
yield "username", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetLikedTweetsBlock(Block):
|
||||
"""
|
||||
Gets information about tweets liked by you
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "like.read", "offline.access"]
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="ID of the user to get liked tweets for",
|
||||
placeholder="Enter user ID",
|
||||
)
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (5-100)",
|
||||
placeholder="100",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for getting next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Commonly used outputs
|
||||
ids: list[str] = SchemaField(description="All Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that authored the tweets"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that authored the tweets"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="292e7c78-a630-11ef-9f40-df5dffaca106",
|
||||
description="This block gets information about tweets liked by a user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetLikedTweetsBlock.Input,
|
||||
output_schema=TwitterGetLikedTweetsBlock.Output,
|
||||
test_input={
|
||||
"user_id": "1234567890",
|
||||
"max_results": 2,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["12345", "67890"]),
|
||||
("texts", ["Tweet 1", "Tweet 2"]),
|
||||
("userIds", ["67890", "67891"]),
|
||||
("userNames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "12345", "text": "Tweet 1"},
|
||||
{"id": "67890", "text": "Tweet 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_liked_tweets": lambda *args, **kwargs: (
|
||||
["12345", "67890"],
|
||||
["Tweet 1", "Tweet 2"],
|
||||
["67890", "67891"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "12345", "text": "Tweet 1"},
|
||||
{"id": "67890", "text": "Tweet 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_liked_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
user_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_liked_tweets(**params))
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No liked tweets found")
|
||||
|
||||
meta = {}
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in response.includes:
|
||||
user_ids = [str(user["id"]) for user in response.includes["users"]]
|
||||
user_names = [
|
||||
user["username"] for user in response.includes["users"]
|
||||
]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
raise Exception("No liked tweets found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_liked_tweets(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterUnlikeTweetBlock(Block):
|
||||
"""
|
||||
Unlikes a tweet that was previously liked
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "like.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to unlike",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the operation was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1ed5eab8-a630-11ef-8e21-cbbbc80cbb85",
|
||||
description="This block unlikes a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnlikeTweetBlock.Input,
|
||||
output_schema=TwitterUnlikeTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unlike_tweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unlike_tweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unlike(tweet_id=tweet_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unlike_tweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,545 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Literal, Optional, Union, cast
|
||||
|
||||
import tweepy
|
||||
from pydantic import BaseModel
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import (
|
||||
TweetDurationBuilder,
|
||||
TweetExpansionsBuilder,
|
||||
TweetPostBuilder,
|
||||
TweetSearchBuilder,
|
||||
)
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetReplySettingsFilter,
|
||||
TweetTimeWindowInputs,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class Media(BaseModel):
|
||||
discriminator: Literal["media"]
|
||||
media_ids: Optional[List[str]] = None
|
||||
media_tagged_user_ids: Optional[List[str]] = None
|
||||
|
||||
|
||||
class DeepLink(BaseModel):
|
||||
discriminator: Literal["deep_link"]
|
||||
direct_message_deep_link: Optional[str] = None
|
||||
|
||||
|
||||
class Poll(BaseModel):
|
||||
discriminator: Literal["poll"]
|
||||
poll_options: Optional[List[str]] = None
|
||||
poll_duration_minutes: Optional[int] = None
|
||||
|
||||
|
||||
class Place(BaseModel):
|
||||
discriminator: Literal["place"]
|
||||
place_id: Optional[str] = None
|
||||
|
||||
|
||||
class Quote(BaseModel):
|
||||
discriminator: Literal["quote"]
|
||||
quote_tweet_id: Optional[str] = None
|
||||
|
||||
|
||||
class TwitterPostTweetBlock(Block):
|
||||
"""
|
||||
Create a tweet on Twitter with the option to include one additional element such as a media, quote, or deep link.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_text: str | None = SchemaField(
|
||||
description="Text of the tweet to post",
|
||||
placeholder="Enter your tweet",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
for_super_followers_only: bool = SchemaField(
|
||||
description="Tweet exclusively for Super Followers",
|
||||
placeholder="Enter for super followers only",
|
||||
advanced=True,
|
||||
default=False,
|
||||
)
|
||||
|
||||
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Additional tweet data (media, deep link, poll, place or quote)",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
exclude_reply_user_ids: Optional[List[str]] = SchemaField(
|
||||
description="User IDs to exclude from reply Tweet thread. [ex - 6253282]",
|
||||
placeholder="Enter user IDs to exclude",
|
||||
advanced=True,
|
||||
default=None,
|
||||
)
|
||||
|
||||
in_reply_to_tweet_id: Optional[str] = SchemaField(
|
||||
description="Tweet ID being replied to. Please note that in_reply_to_tweet_id needs to be in the request if exclude_reply_user_ids is present",
|
||||
default=None,
|
||||
placeholder="Enter in reply to tweet ID",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
reply_settings: TweetReplySettingsFilter = SchemaField(
|
||||
description="Who can reply to the Tweet (mentionedUsers or following)",
|
||||
placeholder="Enter reply settings",
|
||||
advanced=True,
|
||||
default=TweetReplySettingsFilter(All_Users=True),
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
tweet_id: str = SchemaField(description="ID of the created tweet")
|
||||
tweet_url: str = SchemaField(description="URL to the tweet")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the tweet posting failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="7bb0048a-a630-11ef-aeb8-abc0dadb9b12",
|
||||
description="This block posts a tweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterPostTweetBlock.Input,
|
||||
output_schema=TwitterPostTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_text": "This is a test tweet.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"attachment": {
|
||||
"discriminator": "deep_link",
|
||||
"direct_message_deep_link": "https://twitter.com/messages/compose",
|
||||
},
|
||||
"for_super_followers_only": False,
|
||||
"exclude_reply_user_ids": [],
|
||||
"in_reply_to_tweet_id": "",
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("tweet_id", "1234567890"),
|
||||
("tweet_url", "https://twitter.com/user/status/1234567890"),
|
||||
],
|
||||
test_mock={
|
||||
"post_tweet": lambda *args, **kwargs: (
|
||||
"1234567890",
|
||||
"https://twitter.com/user/status/1234567890",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def post_tweet(
|
||||
self,
|
||||
credentials: TwitterCredentials,
|
||||
input_txt: str | None,
|
||||
attachment: Union[Media, DeepLink, Poll, Place, Quote] | None,
|
||||
for_super_followers_only: bool,
|
||||
exclude_reply_user_ids: Optional[List[str]],
|
||||
in_reply_to_tweet_id: Optional[str],
|
||||
reply_settings: TweetReplySettingsFilter,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = (
|
||||
TweetPostBuilder()
|
||||
.add_text(input_txt)
|
||||
.add_super_followers(for_super_followers_only)
|
||||
.add_reply_settings(
|
||||
exclude_reply_user_ids or [],
|
||||
in_reply_to_tweet_id or "",
|
||||
reply_settings,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(attachment, Media):
|
||||
params.add_media(
|
||||
attachment.media_ids or [], attachment.media_tagged_user_ids or []
|
||||
)
|
||||
elif isinstance(attachment, DeepLink):
|
||||
params.add_deep_link(attachment.direct_message_deep_link or "")
|
||||
elif isinstance(attachment, Poll):
|
||||
params.add_poll_options(attachment.poll_options or [])
|
||||
params.add_poll_duration(attachment.poll_duration_minutes or 0)
|
||||
elif isinstance(attachment, Place):
|
||||
params.add_place(attachment.place_id or "")
|
||||
elif isinstance(attachment, Quote):
|
||||
params.add_quote(attachment.quote_tweet_id or "")
|
||||
|
||||
tweet = cast(Response, client.create_tweet(**params.build()))
|
||||
|
||||
if not tweet.data:
|
||||
raise Exception("Failed to create tweet")
|
||||
|
||||
tweet_id = tweet.data["id"]
|
||||
tweet_url = f"https://twitter.com/user/status/{tweet_id}"
|
||||
return str(tweet_id), tweet_url
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
tweet_id, tweet_url = self.post_tweet(
|
||||
credentials,
|
||||
input_data.tweet_text,
|
||||
input_data.attachment,
|
||||
input_data.for_super_followers_only,
|
||||
input_data.exclude_reply_user_ids,
|
||||
input_data.in_reply_to_tweet_id,
|
||||
input_data.reply_settings,
|
||||
)
|
||||
yield "tweet_id", tweet_id
|
||||
yield "tweet_url", tweet_url
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterDeleteTweetBlock(Block):
|
||||
"""
|
||||
Deletes a tweet on Twitter using twitter Id
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to delete",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the tweet was successfully deleted"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the tweet deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="761babf0-a630-11ef-a03d-abceb082f58f",
|
||||
description="This block deletes a tweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterDeleteTweetBlock.Input,
|
||||
output_schema=TwitterDeleteTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"delete_tweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_tweet(credentials: TwitterCredentials, tweet_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
client.delete_tweet(id=tweet_id, user_auth=False)
|
||||
return True
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.delete_tweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterSearchRecentTweetsBlock(Block):
|
||||
"""
|
||||
Searches all public Tweets in Twitter history
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
query: str = SchemaField(
|
||||
description="Search query (up to 1024 characters)",
|
||||
placeholder="Enter search query",
|
||||
)
|
||||
|
||||
max_results: int = SchemaField(
|
||||
description="Maximum number of results per page (10-500)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination: str | None = SchemaField(
|
||||
description="Token for pagination",
|
||||
default="",
|
||||
placeholder="Enter pagination token",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Commonly used outputs
|
||||
tweet_ids: list[str] = SchemaField(description="All Tweet IDs")
|
||||
tweet_texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="53e5cf8e-a630-11ef-ba85-df6d666fa5d5",
|
||||
description="This block searches all public Tweets in Twitter history.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterSearchRecentTweetsBlock.Input,
|
||||
output_schema=TwitterSearchRecentTweetsBlock.Output,
|
||||
test_input={
|
||||
"query": "from:twitterapi #twitterapi",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 2,
|
||||
"start_time": "2024-12-14T18:30:00.000Z",
|
||||
"end_time": "2024-12-17T18:30:00.000Z",
|
||||
"since_id": None,
|
||||
"until_id": None,
|
||||
"sort_order": None,
|
||||
"pagination": None,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("tweet_ids", ["1373001119480344583", "1372627771717869568"]),
|
||||
(
|
||||
"tweet_texts",
|
||||
[
|
||||
"Looking to get started with the Twitter API but new to APIs in general?",
|
||||
"Thanks to everyone who joined and made today a great session!",
|
||||
],
|
||||
),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{
|
||||
"id": "1373001119480344583",
|
||||
"text": "Looking to get started with the Twitter API but new to APIs in general?",
|
||||
},
|
||||
{
|
||||
"id": "1372627771717869568",
|
||||
"text": "Thanks to everyone who joined and made today a great session!",
|
||||
},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"search_tweets": lambda *args, **kwargs: (
|
||||
["1373001119480344583", "1372627771717869568"],
|
||||
[
|
||||
"Looking to get started with the Twitter API but new to APIs in general?",
|
||||
"Thanks to everyone who joined and made today a great session!",
|
||||
],
|
||||
[
|
||||
{
|
||||
"id": "1373001119480344583",
|
||||
"text": "Looking to get started with the Twitter API but new to APIs in general?",
|
||||
},
|
||||
{
|
||||
"id": "1372627771717869568",
|
||||
"text": "Thanks to everyone who joined and made today a great session!",
|
||||
},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def search_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
query: str,
|
||||
max_results: int,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
since_id: str | None,
|
||||
until_id: str | None,
|
||||
sort_order: str | None,
|
||||
pagination: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
# Building common params
|
||||
params = (
|
||||
TweetSearchBuilder()
|
||||
.add_query(query)
|
||||
.add_pagination(max_results, pagination)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding time window to params If required by the user
|
||||
params = (
|
||||
TweetDurationBuilder(params)
|
||||
.add_start_time(start_time)
|
||||
.add_end_time(end_time)
|
||||
.add_since_id(since_id)
|
||||
.add_until_id(until_id)
|
||||
.add_sort_order(sort_order)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.search_recent_tweets(**params))
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
meta = {}
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
return tweet_ids, tweet_texts, data, included, meta, next_token
|
||||
|
||||
raise Exception("No tweets found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, data, included, meta, next_token = self.search_tweets(
|
||||
credentials,
|
||||
input_data.query,
|
||||
input_data.max_results,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.since_id,
|
||||
input_data.until_id,
|
||||
input_data.sort_order,
|
||||
input_data.pagination,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "tweet_ids", ids
|
||||
if texts:
|
||||
yield "tweet_texts", texts
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,222 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import TweetExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExcludesFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterGetQuoteTweetsBlock(Block):
|
||||
"""
|
||||
Gets quote tweets for a specified tweet ID
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to get quotes for",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Number of results to return (max 100)",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
exclude: TweetExcludesFilter | None = SchemaField(
|
||||
description="Types of tweets to exclude", advanced=True, default=None
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination",
|
||||
advanced=True,
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Commonly used outputs
ids: list = SchemaField(description="All Tweet IDs")
|
||||
texts: list = SchemaField(description="All Tweet texts")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9fbdd208-a630-11ef-9b97-ab7a3a695ca3",
|
||||
description="This block gets quote tweets for a specific tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetQuoteTweetsBlock.Input,
|
||||
output_schema=TwitterGetQuoteTweetsBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"max_results": 2,
|
||||
"pagination_token": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["12345", "67890"]),
|
||||
("texts", ["Tweet 1", "Tweet 2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "12345", "text": "Tweet 1"},
|
||||
{"id": "67890", "text": "Tweet 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_quote_tweets": lambda *args, **kwargs: (
|
||||
["12345", "67890"],
|
||||
["Tweet 1", "Tweet 2"],
|
||||
[
|
||||
{"id": "12345", "text": "Tweet 1"},
|
||||
{"id": "67890", "text": "Tweet 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_quote_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
max_results: int | None,
|
||||
exclude: TweetExcludesFilter | None,
|
||||
pagination_token: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": tweet_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"exclude": None if exclude == TweetExcludesFilter() else exclude,
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_quote_tweets(**params))
|
||||
|
||||
meta = {}
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
return tweet_ids, tweet_texts, data, included, meta, next_token
|
||||
|
||||
raise Exception("No quote tweets found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, data, included, meta, next_token = self.get_quote_tweets(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
input_data.max_results,
|
||||
input_data.exclude,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,363 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import UserExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterRetweetBlock(Block):
|
||||
"""
|
||||
Retweets a tweet on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to retweet",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the retweet was successful")
|
||||
error: str = SchemaField(description="Error message if the retweet failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bd7b8d3a-a630-11ef-be96-6f4aa4c3c4f4",
|
||||
description="This block retweets a tweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterRetweetBlock.Input,
|
||||
output_schema=TwitterRetweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"retweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def retweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.retweet(
|
||||
tweet_id=tweet_id,
|
||||
user_auth=False,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.retweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterRemoveRetweetBlock(Block):
|
||||
"""
|
||||
Removes a retweet on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "tweet.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to remove retweet",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the retweet was successfully removed"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the removal failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="b6e663f0-a630-11ef-a7f0-8b9b0c542ff8",
|
||||
description="This block removes a retweet on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterRemoveRetweetBlock.Input,
|
||||
output_schema=TwitterRemoveRetweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"remove_retweet": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_retweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unretweet(
|
||||
source_tweet_id=tweet_id,
|
||||
user_auth=False,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.remove_retweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetRetweetersBlock(Block):
|
||||
"""
|
||||
Gets information about who has retweeted a tweet
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="ID of the tweet to get retweeters for",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results per page (1-100)",
|
||||
default=10,
|
||||
placeholder="Enter max results",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Commonly used outputs
|
||||
ids: list = SchemaField(description="List of user ids who retweeted")
|
||||
names: list = SchemaField(description="List of user names who retweeted")
|
||||
usernames: list = SchemaField(
|
||||
description="List of user usernames who retweeted"
|
||||
)
|
||||
next_token: str = SchemaField(description="Token for next page of results")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ad7aa6fa-a630-11ef-a6b0-e7ca640aa030",
|
||||
description="This block gets information about who has retweeted a tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetRetweetersBlock.Input,
|
||||
output_schema=TwitterGetRetweetersBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1234567890",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 1,
|
||||
"pagination_token": "",
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["12345"]),
|
||||
("names", ["Test User"]),
|
||||
("usernames", ["testuser"]),
|
||||
(
|
||||
"data",
|
||||
[{"id": "12345", "name": "Test User", "username": "testuser"}],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_retweeters": lambda *args, **kwargs: (
|
||||
[{"id": "12345", "name": "Test User", "username": "testuser"}],
|
||||
{},
|
||||
{},
|
||||
["12345"],
|
||||
["Test User"],
|
||||
["testuser"],
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_retweeters(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": tweet_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_retweeters(**params))
|
||||
|
||||
meta = {}
|
||||
ids = []
|
||||
names = []
|
||||
usernames = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
ids = [str(user.id) for user in response.data]
|
||||
names = [user.name for user in response.data]
|
||||
usernames = [user.username for user in response.data]
|
||||
return data, included, meta, ids, names, usernames, next_token
|
||||
|
||||
raise Exception("No retweeters found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, meta, ids, names, usernames, next_token = (
|
||||
self.get_retweeters(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if names:
|
||||
yield "names", names
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,757 @@
|
|||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import (
|
||||
TweetDurationBuilder,
|
||||
TweetExpansionsBuilder,
|
||||
)
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetTimeWindowInputs,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterGetUserMentionsBlock(Block):
|
||||
"""
|
||||
Returns Tweets where a single user is mentioned, just put that user id
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="Unique identifier of the user for whom to return Tweets mentioning the user",
|
||||
placeholder="Enter user ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Number of tweets to retrieve (5-100)",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
ids: list[str] = SchemaField(description="List of Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that mentioned the user"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that mentioned the user"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e01c890c-a630-11ef-9e20-37da24888bd0",
|
||||
description="This block retrieves Tweets mentioning a specific user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetUserMentionsBlock.Input,
|
||||
output_schema=TwitterGetUserMentionsBlock.Output,
|
||||
test_input={
|
||||
"user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 2,
|
||||
"start_time": "2024-12-14T18:30:00.000Z",
|
||||
"end_time": "2024-12-17T18:30:00.000Z",
|
||||
"since_id": "",
|
||||
"until_id": "",
|
||||
"sort_order": None,
|
||||
"pagination_token": None,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1373001119480344583", "1372627771717869568"]),
|
||||
("texts", ["Test mention 1", "Test mention 2"]),
|
||||
("userIds", ["67890", "67891"]),
|
||||
("userNames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test mention 1"},
|
||||
{"id": "1372627771717869568", "text": "Test mention 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_mentions": lambda *args, **kwargs: (
|
||||
["1373001119480344583", "1372627771717869568"],
|
||||
["Test mention 1", "Test mention 2"],
|
||||
["67890", "67891"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test mention 1"},
|
||||
{"id": "1372627771717869568", "text": "Test mention 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_mentions(
|
||||
credentials: TwitterCredentials,
|
||||
user_id: str,
|
||||
max_results: int | None,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
since_id: str | None,
|
||||
until_id: str | None,
|
||||
sort_order: str | None,
|
||||
pagination: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": None if pagination == "" else pagination,
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding time window to params If required by the user
|
||||
params = (
|
||||
TweetDurationBuilder(params)
|
||||
.add_start_time(start_time)
|
||||
.add_end_time(end_time)
|
||||
.add_since_id(since_id)
|
||||
.add_until_id(until_id)
|
||||
.add_sort_order(sort_order)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.get_users_mentions(**params),
|
||||
)
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
meta = response.meta or {}
|
||||
next_token = meta.get("next_token", "")
|
||||
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in included:
|
||||
user_ids = [str(user["id"]) for user in included["users"]]
|
||||
user_names = [user["username"] for user in included["users"]]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_mentions(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.since_id,
|
||||
input_data.until_id,
|
||||
input_data.sort_order,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetHomeTimelineBlock(Block):
|
||||
"""
|
||||
Returns a collection of the most recent Tweets and Retweets posted by you and users you follow
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Number of tweets to retrieve (5-100)",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
ids: list[str] = SchemaField(description="List of Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that authored the tweets"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that authored the tweets"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d222a070-a630-11ef-a18a-3f52f76c6962",
|
||||
description="This block retrieves the authenticated user's home timeline.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetHomeTimelineBlock.Input,
|
||||
output_schema=TwitterGetHomeTimelineBlock.Output,
|
||||
test_input={
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 2,
|
||||
"start_time": "2024-12-14T18:30:00.000Z",
|
||||
"end_time": "2024-12-17T18:30:00.000Z",
|
||||
"since_id": None,
|
||||
"until_id": None,
|
||||
"sort_order": None,
|
||||
"pagination_token": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1373001119480344583", "1372627771717869568"]),
|
||||
("texts", ["Test tweet 1", "Test tweet 2"]),
|
||||
("userIds", ["67890", "67891"]),
|
||||
("userNames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_timeline": lambda *args, **kwargs: (
|
||||
["1373001119480344583", "1372627771717869568"],
|
||||
["Test tweet 1", "Test tweet 2"],
|
||||
["67890", "67891"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_timeline(
|
||||
credentials: TwitterCredentials,
|
||||
max_results: int | None,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
since_id: str | None,
|
||||
until_id: str | None,
|
||||
sort_order: str | None,
|
||||
pagination: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"max_results": max_results,
|
||||
"pagination_token": None if pagination == "" else pagination,
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding time window to params If required by the user
|
||||
params = (
|
||||
TweetDurationBuilder(params)
|
||||
.add_start_time(start_time)
|
||||
.add_end_time(end_time)
|
||||
.add_since_id(since_id)
|
||||
.add_until_id(until_id)
|
||||
.add_sort_order(sort_order)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.get_home_timeline(**params),
|
||||
)
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
meta = response.meta or {}
|
||||
next_token = meta.get("next_token", "")
|
||||
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in included:
|
||||
user_ids = [str(user["id"]) for user in included["users"]]
|
||||
user_names = [user["username"] for user in included["users"]]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_timeline(
|
||||
credentials,
|
||||
input_data.max_results,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.since_id,
|
||||
input_data.until_id,
|
||||
input_data.sort_order,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetUserTweetsBlock(Block):
|
||||
"""
|
||||
Returns Tweets composed by a single user, specified by the requested user ID
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="Unique identifier of the Twitter account (user ID) for whom to return results",
|
||||
placeholder="Enter user ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Number of tweets to retrieve (5-100)",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
ids: list[str] = SchemaField(description="List of Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that authored the tweets"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that authored the tweets"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c44c3ef2-a630-11ef-9ff7-eb7b5ea3a5cb",
|
||||
description="This block retrieves Tweets composed by a single user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetUserTweetsBlock.Input,
|
||||
output_schema=TwitterGetUserTweetsBlock.Output,
|
||||
test_input={
|
||||
"user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 2,
|
||||
"start_time": "2024-12-14T18:30:00.000Z",
|
||||
"end_time": "2024-12-17T18:30:00.000Z",
|
||||
"since_id": None,
|
||||
"until_id": None,
|
||||
"sort_order": None,
|
||||
"pagination_token": None,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1373001119480344583", "1372627771717869568"]),
|
||||
("texts", ["Test tweet 1", "Test tweet 2"]),
|
||||
("userIds", ["67890", "67891"]),
|
||||
("userNames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_user_tweets": lambda *args, **kwargs: (
|
||||
["1373001119480344583", "1372627771717869568"],
|
||||
["Test tweet 1", "Test tweet 2"],
|
||||
["67890", "67891"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
user_id: str,
|
||||
max_results: int | None,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
since_id: str | None,
|
||||
until_id: str | None,
|
||||
sort_order: str | None,
|
||||
pagination: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": None if pagination == "" else pagination,
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding time window to params If required by the user
|
||||
params = (
|
||||
TweetDurationBuilder(params)
|
||||
.add_start_time(start_time)
|
||||
.add_end_time(end_time)
|
||||
.add_since_id(since_id)
|
||||
.add_until_id(until_id)
|
||||
.add_sort_order(sort_order)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.get_users_tweets(**params),
|
||||
)
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
meta = response.meta or {}
|
||||
next_token = meta.get("next_token", "")
|
||||
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in included:
|
||||
user_ids = [str(user["id"]) for user in included["users"]]
|
||||
user_names = [user["username"] for user in included["users"]]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_user_tweets(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.since_id,
|
||||
input_data.until_id,
|
||||
input_data.sort_order,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,361 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import TweetExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ExpansionFilter,
|
||||
TweetExpansionInputs,
|
||||
TweetFieldsFilter,
|
||||
TweetMediaFieldsFilter,
|
||||
TweetPlaceFieldsFilter,
|
||||
TweetPollFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterGetTweetBlock(Block):
|
||||
"""
|
||||
Returns information about a single Tweet specified by the requested ID
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_id: str = SchemaField(
|
||||
description="Unique identifier of the Tweet to request (ex: 1460323737035677698)",
|
||||
placeholder="Enter tweet ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
id: str = SchemaField(description="Tweet ID")
|
||||
text: str = SchemaField(description="Tweet text")
|
||||
userId: str = SchemaField(description="ID of the tweet author")
|
||||
userName: str = SchemaField(description="Username of the tweet author")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: dict = SchemaField(description="Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata about the tweet")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="f5155c3a-a630-11ef-9cc1-a309988b4d92",
|
||||
description="This block retrieves information about a specific Tweet.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetTweetBlock.Input,
|
||||
output_schema=TwitterGetTweetBlock.Output,
|
||||
test_input={
|
||||
"tweet_id": "1460323737035677698",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "1460323737035677698"),
|
||||
("text", "Test tweet content"),
|
||||
("userId", "12345"),
|
||||
("userName", "testuser"),
|
||||
("data", {"id": "1460323737035677698", "text": "Test tweet content"}),
|
||||
("included", {"users": [{"id": "12345", "username": "testuser"}]}),
|
||||
("meta", {"result_count": 1}),
|
||||
],
|
||||
test_mock={
|
||||
"get_tweet": lambda *args, **kwargs: (
|
||||
{"id": "1460323737035677698", "text": "Test tweet content"},
|
||||
{"users": [{"id": "12345", "username": "testuser"}]},
|
||||
{"result_count": 1},
|
||||
"12345",
|
||||
"testuser",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_tweet(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_id: str,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
params = {"id": tweet_id, "user_auth": False}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_tweet(**params))
|
||||
|
||||
meta = {}
|
||||
user_id = ""
|
||||
user_name = ""
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_dict(response.data)
|
||||
|
||||
if included and "users" in included:
|
||||
user_id = str(included["users"][0]["id"])
|
||||
user_name = included["users"][0]["username"]
|
||||
|
||||
if response.data:
|
||||
return data, included, meta, user_id, user_name
|
||||
|
||||
raise Exception("Tweet not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
|
||||
tweet_data, included, meta, user_id, user_name = self.get_tweet(
|
||||
credentials,
|
||||
input_data.tweet_id,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
|
||||
yield "id", str(tweet_data["id"])
|
||||
yield "text", tweet_data["text"]
|
||||
if user_id:
|
||||
yield "userId", user_id
|
||||
if user_name:
|
||||
yield "userName", user_name
|
||||
yield "data", tweet_data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetTweetsBlock(Block):
|
||||
"""
|
||||
Returns information about multiple Tweets specified by the requested IDs
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
tweet_ids: list[str] = SchemaField(
|
||||
description="List of Tweet IDs to request (up to 100)",
|
||||
placeholder="Enter tweet IDs",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
ids: list[str] = SchemaField(description="All Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that authored the tweets"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that authored the tweets"
|
||||
)
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata about the tweets")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="e7cc5420-a630-11ef-bfaf-13bdd8096a51",
|
||||
description="This block retrieves information about multiple Tweets.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetTweetsBlock.Input,
|
||||
output_schema=TwitterGetTweetsBlock.Output,
|
||||
test_input={
|
||||
"tweet_ids": ["1460323737035677698"],
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1460323737035677698"]),
|
||||
("texts", ["Test tweet content"]),
|
||||
("userIds", ["67890"]),
|
||||
("userNames", ["testuser1"]),
|
||||
("data", [{"id": "1460323737035677698", "text": "Test tweet content"}]),
|
||||
("included", {"users": [{"id": "67890", "username": "testuser1"}]}),
|
||||
("meta", {"result_count": 1}),
|
||||
],
|
||||
test_mock={
|
||||
"get_tweets": lambda *args, **kwargs: (
|
||||
["1460323737035677698"], # ids
|
||||
["Test tweet content"], # texts
|
||||
["67890"], # user_ids
|
||||
["testuser1"], # user_names
|
||||
[
|
||||
{"id": "1460323737035677698", "text": "Test tweet content"}
|
||||
], # data
|
||||
{"users": [{"id": "67890", "username": "testuser1"}]}, # included
|
||||
{"result_count": 1}, # meta
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
tweet_ids: list[str],
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
params = {"ids": tweet_ids, "user_auth": False}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_tweets(**params))
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
meta = {}
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if included and "users" in included:
|
||||
for user in included["users"]:
|
||||
user_ids.append(str(user["id"]))
|
||||
user_names.append(user["username"])
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
|
||||
return tweet_ids, tweet_texts, user_ids, user_names, data, included, meta
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta = self.get_tweets(
|
||||
credentials,
|
||||
input_data.tweet_ids,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,305 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import UserExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import IncludesSerializer
|
||||
from backend.blocks.twitter._types import (
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnblockUserBlock(Block):
|
||||
"""
|
||||
Unblock a specific user on Twitter. The request succeeds with no action when the user sends a request to a user they're not blocking or have already unblocked.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["block.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to unblock",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the unblock was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0f1b6570-a631-11ef-a3ea-230cbe9650dd",
|
||||
description="This block unblocks a specific user on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnblockUserBlock.Input,
|
||||
output_schema=TwitterUnblockUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unblock_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unblock_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unblock(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unblock_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetBlockedUsersBlock(Block):
|
||||
"""
|
||||
Get a list of users who are blocked by the authenticating user
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access", "block.read"]
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-1000, default 100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for retrieving next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
user_ids: list[str] = SchemaField(description="List of blocked user IDs")
|
||||
usernames_: list[str] = SchemaField(description="List of blocked usernames")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="05f409e8-a631-11ef-ae89-93de863ee30d",
|
||||
description="This block retrieves a list of users blocked by the authenticating user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetBlockedUsersBlock.Input,
|
||||
output_schema=TwitterGetBlockedUsersBlock.Output,
|
||||
test_input={
|
||||
"max_results": 10,
|
||||
"pagination_token": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("user_ids", ["12345", "67890"]),
|
||||
("usernames_", ["testuser1", "testuser2"]),
|
||||
],
|
||||
test_mock={
|
||||
"get_blocked_users": lambda *args, **kwargs: (
|
||||
{}, # included
|
||||
{}, # meta
|
||||
["12345", "67890"], # user_ids
|
||||
["testuser1", "testuser2"], # usernames
|
||||
None, # next_token
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_blocked_users(
|
||||
credentials: TwitterCredentials,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_blocked(**params))
|
||||
|
||||
meta = {}
|
||||
user_ids = []
|
||||
usernames = []
|
||||
next_token = None
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
|
||||
if response.data:
|
||||
for user in response.data:
|
||||
user_ids.append(str(user.id))
|
||||
usernames.append(user.username)
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
if "next_token" in meta:
|
||||
next_token = meta["next_token"]
|
||||
|
||||
if user_ids and usernames:
|
||||
return included, meta, user_ids, usernames, next_token
|
||||
else:
|
||||
raise tweepy.TweepyException("No blocked users found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
included, meta, user_ids, usernames, next_token = self.get_blocked_users(
|
||||
credentials,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if user_ids:
|
||||
yield "user_ids", user_ids
|
||||
if usernames:
|
||||
yield "usernames_", usernames
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterBlockUserBlock(Block):
|
||||
"""
|
||||
Block a specific user on Twitter
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["block.write", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to block",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the block was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="fc258b94-a630-11ef-abc3-df050b75b816",
|
||||
description="This block blocks a specific user on Twitter.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterBlockUserBlock.Input,
|
||||
output_schema=TwitterBlockUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"block_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def block_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.block(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.block_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,510 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import UserExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnfollowUserBlock(Block):
|
||||
"""
|
||||
Allows a user to unfollow another user specified by target user ID.
|
||||
The request succeeds with no action when the authenticated user sends a request to a user they're not following or have already unfollowed.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "users.write", "follows.write", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to unfollow",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the unfollow action was successful"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="37e386a4-a631-11ef-b7bd-b78204b35fa4",
|
||||
description="This block unfollows a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnfollowUserBlock.Input,
|
||||
output_schema=TwitterUnfollowUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unfollow_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unfollow_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unfollow_user(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unfollow_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterFollowUserBlock(Block):
|
||||
"""
|
||||
Allows a user to follow another user specified by target user ID. If the target user does not have public Tweets,
|
||||
this endpoint will send a follow request. The request succeeds with no action when the authenticated user sends a
|
||||
request to a user they're already following, or if they're sending a follower request to a user that does not have
|
||||
public Tweets.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "users.write", "follows.write", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to follow",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the follow action was successful"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1aae6a5e-a631-11ef-a090-435900c6d429",
|
||||
description="This block follows a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterFollowUserBlock.Input,
|
||||
output_schema=TwitterFollowUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"follow_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def follow_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.follow_user(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.follow_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetFollowersBlock(Block):
|
||||
"""
|
||||
Retrieves a list of followers for a specified Twitter user ID
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access", "follows.read"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID whose followers you would like to retrieve",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-1000, default 100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for retrieving next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
ids: list[str] = SchemaField(description="List of follower user IDs")
|
||||
usernames: list[str] = SchemaField(description="List of follower usernames")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
data: list[dict] = SchemaField(description="Complete user data for followers")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="30f66410-a631-11ef-8fe7-d7f888b4f43c",
|
||||
description="This block retrieves followers of a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetFollowersBlock.Input,
|
||||
output_schema=TwitterGetFollowersBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"max_results": 1,
|
||||
"pagination_token": "",
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1234567890"]),
|
||||
("usernames", ["testuser"]),
|
||||
("data", [{"id": "1234567890", "username": "testuser"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_followers": lambda *args, **kwargs: (
|
||||
["1234567890"],
|
||||
["testuser"],
|
||||
[{"id": "1234567890", "username": "testuser"}],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_followers(
|
||||
credentials: TwitterCredentials,
|
||||
target_user_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": target_user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_users_followers(**params))
|
||||
|
||||
meta = {}
|
||||
follower_ids = []
|
||||
follower_usernames = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
follower_ids = [str(user.id) for user in response.data]
|
||||
follower_usernames = [user.username for user in response.data]
|
||||
|
||||
return (
|
||||
follower_ids,
|
||||
follower_usernames,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
raise Exception("Followers not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, usernames, data, includes, meta, next_token = self.get_followers(
|
||||
credentials,
|
||||
input_data.target_user_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if includes:
|
||||
yield "includes", includes
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetFollowingBlock(Block):
|
||||
"""
|
||||
Retrieves a list of users that a specified Twitter user ID is following
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access", "follows.read"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID whose following you would like to retrieve",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Maximum number of results to return (1-1000, default 100)",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for retrieving next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
ids: list[str] = SchemaField(description="List of following user IDs")
|
||||
usernames: list[str] = SchemaField(description="List of following usernames")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
data: list[dict] = SchemaField(description="Complete user data for following")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="264a399c-a631-11ef-a97d-bfde4ca91173",
|
||||
description="This block retrieves the users that a specified Twitter user is following.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetFollowingBlock.Input,
|
||||
output_schema=TwitterGetFollowingBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"max_results": 1,
|
||||
"pagination_token": None,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1234567890"]),
|
||||
("usernames", ["testuser"]),
|
||||
("data", [{"id": "1234567890", "username": "testuser"}]),
|
||||
],
|
||||
test_mock={
|
||||
"get_following": lambda *args, **kwargs: (
|
||||
["1234567890"],
|
||||
["testuser"],
|
||||
[{"id": "1234567890", "username": "testuser"}],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_following(
|
||||
credentials: TwitterCredentials,
|
||||
target_user_id: str,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": target_user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_users_following(**params))
|
||||
|
||||
meta = {}
|
||||
following_ids = []
|
||||
following_usernames = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
following_ids = [str(user.id) for user in response.data]
|
||||
following_usernames = [user.username for user in response.data]
|
||||
|
||||
return (
|
||||
following_ids,
|
||||
following_usernames,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
raise Exception("Following not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, usernames, data, includes, meta, next_token = self.get_following(
|
||||
credentials,
|
||||
input_data.target_user_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if includes:
|
||||
yield "includes", includes
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,328 @@
|
|||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import UserExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnmuteUserBlock(Block):
|
||||
"""
|
||||
Allows a user to unmute another user specified by target user ID.
|
||||
The request succeeds with no action when the user sends a request to a user they're not muting or have already unmuted.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "users.write", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to unmute",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the unmute action was successful"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="40458504-a631-11ef-940b-eff92be55422",
|
||||
description="This block unmutes a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnmuteUserBlock.Input,
|
||||
output_schema=TwitterUnmuteUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"unmute_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unmute_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unmute(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unmute_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetMutedUsersBlock(Block):
|
||||
"""
|
||||
Returns a list of users who are muted by the authenticating user
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access"]
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="The maximum number of results to be returned per page (1-1000). Default is 100.",
|
||||
placeholder="Enter max results",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token to request next/previous page of results",
|
||||
placeholder="Enter pagination token",
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
ids: list[str] = SchemaField(description="List of muted user IDs")
|
||||
usernames: list[str] = SchemaField(description="List of muted usernames")
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
data: list[dict] = SchemaField(description="Complete user data for muted users")
|
||||
includes: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
meta: dict = SchemaField(description="Metadata including pagination info")
|
||||
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="475024da-a631-11ef-9ccd-f724b8b03cda",
|
||||
description="This block gets a list of users muted by the authenticating user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetMutedUsersBlock.Input,
|
||||
output_schema=TwitterGetMutedUsersBlock.Output,
|
||||
test_input={
|
||||
"max_results": 2,
|
||||
"pagination_token": "",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["12345", "67890"]),
|
||||
("usernames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "12345", "username": "testuser1"},
|
||||
{"id": "67890", "username": "testuser2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_muted_users": lambda *args, **kwargs: (
|
||||
["12345", "67890"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "12345", "username": "testuser1"},
|
||||
{"id": "67890", "username": "testuser2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_muted_users(
|
||||
credentials: TwitterCredentials,
|
||||
max_results: int | None,
|
||||
pagination_token: str | None,
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"max_results": max_results,
|
||||
"pagination_token": (
|
||||
None if pagination_token == "" else pagination_token
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_muted(**params))
|
||||
|
||||
meta = {}
|
||||
user_ids = []
|
||||
usernames = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
user_ids = [str(item.id) for item in response.data]
|
||||
usernames = [item.username for item in response.data]
|
||||
|
||||
return user_ids, usernames, data, included, meta, next_token
|
||||
|
||||
raise Exception("Muted users not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, usernames, data, includes, meta, next_token = self.get_muted_users(
|
||||
credentials,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if usernames:
|
||||
yield "usernames", usernames
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if includes:
|
||||
yield "includes", includes
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterMuteUserBlock(Block):
|
||||
"""
|
||||
Allows a user to mute another user specified by target user ID
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "users.write", "offline.access"]
|
||||
)
|
||||
|
||||
target_user_id: str = SchemaField(
|
||||
description="The user ID of the user that you would like to mute",
|
||||
placeholder="Enter target user ID",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(
|
||||
description="Whether the mute action was successful"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="4d1919d0-a631-11ef-90ab-3b73af9ce8f1",
|
||||
description="This block mutes a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterMuteUserBlock.Input,
|
||||
output_schema=TwitterMuteUserBlock.Output,
|
||||
test_input={
|
||||
"target_user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("success", True),
|
||||
],
|
||||
test_mock={"mute_user": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mute_user(credentials: TwitterCredentials, target_user_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.mute(target_user_id=target_user_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.mute_user(credentials, input_data.target_user_id)
|
||||
yield "success", success
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -0,0 +1,383 @@
|
|||
from typing import Literal, Union, cast
|
||||
|
||||
import tweepy
|
||||
from pydantic import BaseModel
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import UserExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
TweetFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
UserExpansionInputs,
|
||||
UserExpansionsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class UserId(BaseModel):
|
||||
discriminator: Literal["user_id"]
|
||||
user_id: str = SchemaField(description="The ID of the user to lookup", default="")
|
||||
|
||||
|
||||
class Username(BaseModel):
|
||||
discriminator: Literal["username"]
|
||||
username: str = SchemaField(
|
||||
description="The Twitter username (handle) of the user", default=""
|
||||
)
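# The two models above form a discriminated union keyed on the "discriminator" field.
# For reference, the two accepted input shapes look like this (the username example is
# the block's own test_input; the user_id value is an illustrative placeholder):
#
#   {"discriminator": "user_id", "user_id": "783214"}
#   {"discriminator": "username", "username": "twitter"}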
|
||||
|
||||
|
||||
class TwitterGetUserBlock(Block):
|
||||
"""
|
||||
Gets information about a single Twitter user specified by ID or username
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access"]
|
||||
)
|
||||
|
||||
identifier: Union[UserId, Username] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Choose whether to identify the user by their unique Twitter ID or by their username",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
id: str = SchemaField(description="User ID")
|
||||
username_: str = SchemaField(description="User username")
|
||||
name_: str = SchemaField(description="User name")
|
||||
|
||||
# Complete outputs
|
||||
data: dict = SchemaField(description="Complete user data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5446db8e-a631-11ef-812a-cf315d373ee9",
|
||||
description="This block retrieves information about a specified Twitter user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetUserBlock.Input,
|
||||
output_schema=TwitterGetUserBlock.Output,
|
||||
test_input={
|
||||
"identifier": {"discriminator": "username", "username": "twitter"},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", "783214"),
|
||||
("username_", "twitter"),
|
||||
("name_", "Twitter"),
|
||||
(
|
||||
"data",
|
||||
{
|
||||
"user": {
|
||||
"id": "783214",
|
||||
"username": "twitter",
|
||||
"name": "Twitter",
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_user": lambda *args, **kwargs: (
|
||||
{
|
||||
"user": {
|
||||
"id": "783214",
|
||||
"username": "twitter",
|
||||
"name": "Twitter",
|
||||
}
|
||||
},
|
||||
{},
|
||||
"twitter",
|
||||
"783214",
|
||||
"Twitter",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user(
|
||||
credentials: TwitterCredentials,
|
||||
identifier: Union[UserId, Username],
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": identifier.user_id if isinstance(identifier, UserId) else None,
|
||||
"username": (
|
||||
identifier.username if isinstance(identifier, Username) else None
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_user(**params))
|
||||
|
||||
username = ""
|
||||
id = ""
|
||||
name = ""
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_dict(response.data)
|
||||
|
||||
if response.data:
|
||||
username = response.data.username
|
||||
id = str(response.data.id)
|
||||
name = response.data.name
|
||||
|
||||
if username and id:
|
||||
return data, included, username, id, name
|
||||
else:
|
||||
raise tweepy.TweepyException("User not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, username, id, name = self.get_user(
|
||||
credentials,
|
||||
input_data.identifier,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if id:
|
||||
yield "id", id
|
||||
if username:
|
||||
yield "username_", username
|
||||
if name:
|
||||
yield "name_", name
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class UserIdList(BaseModel):
|
||||
discriminator: Literal["user_id_list"]
|
||||
user_ids: list[str] = SchemaField(
|
||||
description="List of user IDs to lookup (max 100)",
|
||||
placeholder="Enter user IDs",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
class UsernameList(BaseModel):
|
||||
discriminator: Literal["username_list"]
|
||||
usernames: list[str] = SchemaField(
|
||||
description="List of Twitter usernames/handles to lookup (max 100)",
|
||||
placeholder="Enter usernames",
|
||||
default=[],
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
|
||||
class TwitterGetUsersBlock(Block):
|
||||
"""
|
||||
Gets information about multiple Twitter users specified by IDs or usernames
|
||||
"""
|
||||
|
||||
class Input(UserExpansionInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["users.read", "offline.access"]
|
||||
)
|
||||
|
||||
identifier: Union[UserIdList, UsernameList] = SchemaField(
|
||||
discriminator="discriminator",
|
||||
description="Choose whether to identify users by their unique Twitter IDs or by their usernames",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common outputs
|
||||
ids: list[str] = SchemaField(description="User IDs")
|
||||
usernames_: list[str] = SchemaField(description="User usernames")
|
||||
names_: list[str] = SchemaField(description="User names")
|
||||
|
||||
# Complete outputs
|
||||
data: list[dict] = SchemaField(description="Complete users data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data requested via expansions"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="5abc857c-a631-11ef-8cfc-f7b79354f7a1",
|
||||
description="This block retrieves information about multiple Twitter users.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetUsersBlock.Input,
|
||||
output_schema=TwitterGetUsersBlock.Output,
|
||||
test_input={
|
||||
"identifier": {
|
||||
"discriminator": "username_list",
|
||||
"usernames": ["twitter", "twitterdev"],
|
||||
},
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"expansions": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["783214", "2244994945"]),
|
||||
("usernames_", ["twitter", "twitterdev"]),
|
||||
("names_", ["Twitter", "Twitter Dev"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "783214", "username": "twitter", "name": "Twitter"},
|
||||
{
|
||||
"id": "2244994945",
|
||||
"username": "twitterdev",
|
||||
"name": "Twitter Dev",
|
||||
},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_users": lambda *args, **kwargs: (
|
||||
[
|
||||
{"id": "783214", "username": "twitter", "name": "Twitter"},
|
||||
{
|
||||
"id": "2244994945",
|
||||
"username": "twitterdev",
|
||||
"name": "Twitter Dev",
|
||||
},
|
||||
],
|
||||
{},
|
||||
["twitter", "twitterdev"],
|
||||
["783214", "2244994945"],
|
||||
["Twitter", "Twitter Dev"],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_users(
|
||||
credentials: TwitterCredentials,
|
||||
identifier: Union[UserIdList, UsernameList],
|
||||
expansions: UserExpansionsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"ids": (
|
||||
",".join(identifier.user_ids)
|
||||
if isinstance(identifier, UserIdList)
|
||||
else None
|
||||
),
|
||||
"usernames": (
|
||||
",".join(identifier.usernames)
|
||||
if isinstance(identifier, UsernameList)
|
||||
else None
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_users(**params))
|
||||
|
||||
usernames = []
|
||||
ids = []
|
||||
names = []
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
for user in response.data:
|
||||
usernames.append(user.username)
|
||||
ids.append(str(user.id))
|
||||
names.append(user.name)
|
||||
|
||||
if usernames and ids:
|
||||
return data, included, usernames, ids, names
|
||||
else:
|
||||
raise tweepy.TweepyException("Users not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, usernames, ids, names = self.get_users(
|
||||
credentials,
|
||||
input_data.identifier,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if usernames:
|
||||
yield "usernames_", usernames
|
||||
if names:
|
||||
yield "names_", names
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
|
@ -61,6 +61,9 @@ class BlockCategory(Enum):
|
|||
HARDWARE = "Block that interacts with hardware."
|
||||
AGENT = "Block that interacts with other agents."
|
||||
CRM = "Block that interacts with CRM services."
|
||||
SAFETY = (
|
||||
"Block that provides AI safety mechanisms such as detecting harmful content"
|
||||
)
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
@ -51,6 +51,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
|||
LlmModel.LLAMA3_1_405B: 1,
|
||||
LlmModel.LLAMA3_1_70B: 1,
|
||||
LlmModel.LLAMA3_1_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_2: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_8B: 1,
|
||||
LlmModel.OLLAMA_LLAMA3_405B: 1,
|
||||
LlmModel.OLLAMA_DOLPHIN: 1,
|
||||
|
|
|
@ -136,6 +136,7 @@ async def create_graph_execution(
|
|||
graph_version: int,
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> tuple[str, list[ExecutionResult]]:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
|
@ -163,6 +164,7 @@ async def create_graph_execution(
|
|||
]
|
||||
},
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
|
|
|
@ -6,7 +6,13 @@ from datetime import datetime, timezone
|
|||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
import prisma
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.models import (
|
||||
AgentGraph,
|
||||
AgentGraphExecution,
|
||||
AgentNode,
|
||||
AgentNodeLink,
|
||||
StoreListingVersion,
|
||||
)
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
|
@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
|
|||
async def get_graph(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
template: bool = False,
|
||||
user_id: str | None = None,
|
||||
for_export: bool = False,
|
||||
) -> GraphModel | None:
|
||||
|
@ -543,21 +548,36 @@ async def get_graph(
|
|||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
}
|
||||
|
||||
if version is not None:
|
||||
where_clause["version"] = version
|
||||
elif not template:
|
||||
else:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
# TODO: Fix hack workaround to get adding store agents to work
|
||||
if user_id is not None and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
return GraphModel.from_db(graph, for_export) if graph else None
|
||||
|
||||
# The Graph has to be owned by the user or a store listing.
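# (Since "and" binds tighter than "or", the condition below denies access when the
# graph is missing, or when it is neither owned by this user nor published through
# an approved store listing.)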
|
||||
if (
|
||||
graph is None
|
||||
or graph.userId != user_id
|
||||
and not (
|
||||
await StoreListingVersion.prisma().find_first(
|
||||
where=prisma.types.StoreListingVersionWhereInput(
|
||||
agentId=graph_id,
|
||||
agentVersion=version or graph.version,
|
||||
isDeleted=False,
|
||||
StoreListing={"is": {"isApproved": True}},
|
||||
)
|
||||
)
|
||||
)
|
||||
):
|
||||
return None
|
||||
|
||||
return GraphModel.from_db(graph, for_export)
|
||||
|
||||
|
||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||
|
@ -611,9 +631,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
|||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph, user_id)
|
||||
|
||||
if created_graph := await get_graph(
|
||||
graph.id, graph.version, graph.is_template, user_id=user_id
|
||||
):
|
||||
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
|
|
@ -139,6 +139,8 @@ def SchemaField(
|
|||
exclude: bool = False,
|
||||
hidden: Optional[bool] = None,
|
||||
depends_on: list[str] | None = None,
|
||||
image_upload: Optional[bool] = None,
|
||||
image_output: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if default is PydanticUndefined and default_factory is None:
|
||||
|
@ -154,6 +156,8 @@ def SchemaField(
|
|||
"advanced": advanced,
|
||||
"hidden": hidden,
|
||||
"depends_on": depends_on,
|
||||
"image_upload": image_upload,
|
||||
"image_output": image_output,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
@ -222,6 +226,7 @@ class OAuthState(BaseModel):
|
|||
token: str
|
||||
provider: str
|
||||
expires_at: int
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
code_verifier: Optional[str] = None
|
||||
scopes: list[str]
|
||||
|
||||
|
|
|
@ -780,7 +780,8 @@ class ExecutionManager(AppService):
|
|||
graph_id: str,
|
||||
data: BlockInput,
|
||||
user_id: str,
|
||||
graph_version: int | None = None,
|
||||
graph_version: int,
|
||||
preset_id: str | None = None,
|
||||
) -> GraphExecutionEntry:
|
||||
graph: GraphModel | None = self.db_client.get_graph(
|
||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||
|
@ -829,6 +830,7 @@ class ExecutionManager(AppService):
|
|||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
starting_node_execs = []
|
||||
|
|
|
@ -63,7 +63,10 @@ def execute_graph(**kwargs):
|
|||
try:
|
||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||
get_execution_client().add_execution(
|
||||
args.graph_id, args.input_data, args.user_id
|
||||
graph_id=args.graph_id,
|
||||
data=args.input_data,
|
||||
user_id=args.user_id,
|
||||
graph_version=args.graph_version,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
@ -91,6 +93,34 @@ open_router_credentials = APIKeyCredentials(
|
|||
title="Use Credits for Open Router",
|
||||
expires_at=None,
|
||||
)
|
||||
fal_credentials = APIKeyCredentials(
|
||||
id="6c0f5bd0-9008-4638-9d79-4b40b631803e",
|
||||
provider="fal",
|
||||
api_key=SecretStr(settings.secrets.fal_api_key),
|
||||
title="Use Credits for FAL",
|
||||
expires_at=None,
|
||||
)
|
||||
exa_credentials = APIKeyCredentials(
|
||||
id="96153e04-9c6c-4486-895f-5bb683b1ecec",
|
||||
provider="exa",
|
||||
api_key=SecretStr(settings.secrets.exa_api_key),
|
||||
title="Use Credits for Exa search",
|
||||
expires_at=None,
|
||||
)
|
||||
e2b_credentials = APIKeyCredentials(
|
||||
id="78d19fd7-4d59-4a16-8277-3ce310acf2b7",
|
||||
provider="e2b",
|
||||
api_key=SecretStr(settings.secrets.e2b_api_key),
|
||||
title="Use Credits for E2B",
|
||||
expires_at=None,
|
||||
)
|
||||
nvidia_credentials = APIKeyCredentials(
|
||||
id="96b83908-2789-4dec-9968-18f0ece4ceb3",
|
||||
provider="nvidia",
|
||||
api_key=SecretStr(settings.secrets.nvidia_api_key),
|
||||
title="Use Credits for Nvidia",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_CREDENTIALS = [
|
||||
|
@ -104,6 +134,10 @@ DEFAULT_CREDENTIALS = [
|
|||
jina_credentials,
|
||||
unreal_credentials,
|
||||
open_router_credentials,
|
||||
fal_credentials,
|
||||
exa_credentials,
|
||||
e2b_credentials,
|
||||
nvidia_credentials,
|
||||
]
|
||||
|
||||
|
||||
|
@ -155,6 +189,14 @@ class IntegrationCredentialsStore:
|
|||
all_credentials.append(unreal_credentials)
|
||||
if settings.secrets.open_router_api_key:
|
||||
all_credentials.append(open_router_credentials)
|
||||
if settings.secrets.fal_api_key:
|
||||
all_credentials.append(fal_credentials)
|
||||
if settings.secrets.exa_api_key:
|
||||
all_credentials.append(exa_credentials)
|
||||
if settings.secrets.e2b_api_key:
|
||||
all_credentials.append(e2b_credentials)
|
||||
if settings.secrets.nvidia_api_key:
|
||||
all_credentials.append(nvidia_credentials)
|
||||
return all_credentials
|
||||
|
||||
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
|
||||
|
@ -210,18 +252,24 @@ class IntegrationCredentialsStore:
|
|||
]
|
||||
self._set_user_integration_creds(user_id, filtered_credentials)
|
||||
|
||||
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
|
||||
def store_state_token(
|
||||
self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
|
||||
) -> tuple[str, str]:
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
|
||||
(code_challenge, code_verifier) = self._generate_code_challenge()
|
||||
|
||||
state = OAuthState(
|
||||
token=token,
|
||||
provider=provider,
|
||||
code_verifier=code_verifier,
|
||||
expires_at=int(expires_at.timestamp()),
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
with self.locked_user_integrations(user_id):
|
||||
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
oauth_states.append(state)
|
||||
|
@ -231,39 +279,21 @@ class IntegrationCredentialsStore:
|
|||
user_id=user_id, data=user_integrations
|
||||
)
|
||||
|
||||
return token
|
||||
return token, code_challenge
|
||||
|
||||
def get_any_valid_scopes_from_state_token(
|
||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||
"""
|
||||
Generate code challenge using SHA256 from the code verifier.
|
||||
Currently only SHA256 is supported. (If more methods are needed in the future, they can be added here.)
|
||||
"""
|
||||
code_verifier = secrets.token_urlsafe(128)
|
||||
sha256_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(sha256_hash).decode("utf-8")
|
||||
return code_challenge.replace("=", ""), code_verifier
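# Illustrative only: this is the S256 relationship from RFC 7636 that the provider
# re-checks at the token endpoint. The helper below is local to this sketch and simply
# mirrors the logic above; for any (code_challenge, code_verifier) pair produced by
# _generate_code_challenge(), s256(code_verifier) == code_challenge should hold.

def s256(verifier: str) -> str:
    digest = hashlib.sha256(verifier.encode("utf-8")).digest()
    return base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")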
|
||||
|
||||
def verify_state_token(
|
||||
self, user_id: str, token: str, provider: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the valid scopes from the OAuth state token. This will return any valid scopes
|
||||
from any OAuth state token for the given provider. If no valid scopes are found,
|
||||
an empty list is returned. DO NOT RELY ON THIS TOKEN TO AUTHENTICATE A USER, AS IT
|
||||
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
|
||||
THE CODE FOR TOKENS.
|
||||
"""
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_state = next(
|
||||
(
|
||||
state
|
||||
for state in oauth_states
|
||||
if state.token == token
|
||||
and state.provider == provider
|
||||
and state.expires_at > now.timestamp()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if valid_state:
|
||||
return valid_state.scopes
|
||||
|
||||
return []
|
||||
|
||||
def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
|
||||
) -> Optional[OAuthState]:
|
||||
with self.locked_user_integrations(user_id):
|
||||
user_integrations = self._get_user_integrations(user_id)
|
||||
oauth_states = user_integrations.oauth_states
|
||||
|
@ -285,9 +315,9 @@ class IntegrationCredentialsStore:
|
|||
oauth_states.remove(valid_state)
|
||||
user_integrations.oauth_states = oauth_states
|
||||
self.db_manager.update_user_integrations(user_id, user_integrations)
|
||||
return True
|
||||
return valid_state
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
def _set_user_integration_creds(
|
||||
self, user_id: str, credentials: list[Credentials]
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
|||
from .github import GitHubOAuthHandler
|
||||
from .google import GoogleOAuthHandler
|
||||
from .notion import NotionOAuthHandler
|
||||
from .twitter import TwitterOAuthHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..providers import ProviderName
|
||||
|
@ -15,6 +16,7 @@ HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
|
|||
GitHubOAuthHandler,
|
||||
GoogleOAuthHandler,
|
||||
NotionOAuthHandler,
|
||||
TwitterOAuthHandler,
|
||||
]
|
||||
}
|
||||
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
@ -23,7 +23,9 @@ class BaseOAuthHandler(ABC):
|
|||
|
||||
@abstractmethod
|
||||
# --8<-- [start:BaseOAuthHandler3]
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
# --8<-- [end:BaseOAuthHandler3]
|
||||
"""Constructs a login URL that the user can be redirected to"""
|
||||
...
|
||||
|
@ -31,7 +33,7 @@ class BaseOAuthHandler(ABC):
|
|||
@abstractmethod
|
||||
# --8<-- [start:BaseOAuthHandler4]
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str]
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
# --8<-- [end:BaseOAuthHandler4]
|
||||
"""Exchanges the acquired authorization code from login for a set of tokens"""
|
||||
|
|
|
@ -34,7 +34,9 @@ class GitHubOAuthHandler(BaseOAuthHandler):
|
|||
self.token_url = "https://github.com/login/oauth/access_token"
|
||||
self.revoke_url = "https://api.github.com/applications/{client_id}/token"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
|
@ -44,7 +46,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
|
|||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str]
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from google.auth.external_account_authorized_user import (
|
||||
Credentials as ExternalAccountCredentials,
|
||||
|
@ -38,7 +39,9 @@ class GoogleOAuthHandler(BaseOAuthHandler):
|
|||
self.token_uri = "https://oauth2.googleapis.com/token"
|
||||
self.revoke_uri = "https://oauth2.googleapis.com/revoke"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
|
||||
logger.debug(f"Setting up OAuth flow with scopes: {all_scopes}")
|
||||
flow = self._setup_oauth_flow(all_scopes)
|
||||
|
@ -52,7 +55,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
|
|||
return authorization_url
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str]
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
logger.debug(f"Exchanging code for tokens with scopes: {scopes}")
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
@ -26,7 +27,9 @@ class NotionOAuthHandler(BaseOAuthHandler):
|
|||
self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
|
||||
self.token_url = "https://api.notion.com/v1/oauth/token"
|
||||
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
|
@ -37,7 +40,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
|
|||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str]
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
request_body = {
|
||||
"grant_type": "authorization_code",
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from backend.data.model import OAuth2Credentials, ProviderName
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
|
||||
|
||||
class TwitterOAuthHandler(BaseOAuthHandler):
|
||||
PROVIDER_NAME = ProviderName.TWITTER
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = [
|
||||
"tweet.read",
|
||||
"tweet.write",
|
||||
"tweet.moderate.write",
|
||||
"users.read",
|
||||
"follows.read",
|
||||
"follows.write",
|
||||
"offline.access",
|
||||
"space.read",
|
||||
"mute.read",
|
||||
"mute.write",
|
||||
"like.read",
|
||||
"like.write",
|
||||
"list.read",
|
||||
"list.write",
|
||||
"block.read",
|
||||
"block.write",
|
||||
"bookmark.read",
|
||||
"bookmark.write",
|
||||
]
|
||||
|
||||
AUTHORIZE_URL = "https://twitter.com/i/oauth2/authorize"
|
||||
TOKEN_URL = "https://api.x.com/2/oauth2/token"
|
||||
USERNAME_URL = "https://api.x.com/2/users/me"
|
||||
REVOKE_URL = "https://api.x.com/2/oauth2/revoke"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
"""Generate Twitter OAuth 2.0 authorization URL"""
|
||||
# scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
if code_challenge is None:
|
||||
raise ValueError("code_challenge is required for Twitter OAuth")
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": " ".join(self.DEFAULT_SCOPES),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
) -> OAuth2Credentials:
|
||||
"""Exchange authorization code for access tokens"""
|
||||
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
data = {
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = requests.post(self.TOKEN_URL, headers=headers, data=data, auth=auth)
|
||||
response.raise_for_status()
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
username = self._get_username(tokens["access_token"])
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens.get("refresh_token"),
|
||||
access_token_expires_at=int(time.time()) + tokens["expires_in"],
|
||||
refresh_token_expires_at=None,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
def _get_username(self, access_token: str) -> str:
|
||||
"""Get the username from the access token"""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
params = {"user.fields": "username"}
|
||||
|
||||
response = requests.get(
|
||||
f"{self.USERNAME_URL}?{urllib.parse.urlencode(params)}", headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()["data"]["username"]
|
||||
|
||||
def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
|
||||
"""Refresh access tokens using refresh token"""
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available")
|
||||
|
||||
header = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = requests.post(self.TOKEN_URL, headers=header, data=data, auth=auth)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print("HTTP Error:", e)
|
||||
print("Response Content:", response.text)
|
||||
raise
|
||||
|
||||
tokens = response.json()
|
||||
|
||||
username = self._get_username(tokens["access_token"])
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=self.PROVIDER_NAME,
|
||||
title=None,
|
||||
username=username,
|
||||
access_token=tokens["access_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
access_token_expires_at=int(time.time()) + tokens["expires_in"],
|
||||
scopes=credentials.scopes,
|
||||
refresh_token_expires_at=None,
|
||||
)
|
||||
|
||||
def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
"""Revoke the access token"""
|
||||
|
||||
header = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
}
|
||||
|
||||
auth = (self.client_id, self.client_secret)
|
||||
|
||||
response = requests.post(self.REVOKE_URL, headers=header, data=data, auth=auth)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print("HTTP Error:", e)
|
||||
print("Response Content:", response.text)
|
||||
raise
|
||||
|
||||
return response.status_code == 200
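# A minimal usage sketch, assuming placeholder client credentials and redirect URI
# (none of these values come from the real settings), showing how the pieces above fit
# together: the login URL carries the state token and S256 code challenge, and the
# callback code is later exchanged together with the matching verifier.

def example_twitter_oauth_exchange(
    state_token: str, code_challenge: str, code: str, code_verifier: str
) -> OAuth2Credentials:
    handler = TwitterOAuthHandler(
        client_id="<TWITTER_CLIENT_ID>",  # placeholder, not read from settings
        client_secret="<TWITTER_CLIENT_SECRET>",  # placeholder
        redirect_uri="<REDIRECT_URI>",  # placeholder
    )
    # Step 1: send the user to the authorization URL (the handler currently applies
    # DEFAULT_SCOPES regardless of the scopes argument).
    login_url = handler.get_login_url([], state_token, code_challenge)
    print("Visit:", login_url)
    # Step 2: after the redirect back, exchange the code plus the matching verifier.
    return handler.exchange_code_for_tokens(code, [], code_verifier)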
@ -19,6 +19,7 @@ class ProviderName(str, Enum):
|
|||
JINA = "jina"
|
||||
MEDIUM = "medium"
|
||||
NOTION = "notion"
|
||||
NVIDIA = "nvidia"
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
OPENWEATHERMAP = "openweathermap"
|
||||
|
@ -27,5 +28,6 @@ class ProviderName(str, Enum):
|
|||
REPLICATE = "replicate"
|
||||
REVID = "revid"
|
||||
SLANT3D = "slant3d"
|
||||
TWITTER = "twitter"
|
||||
UNREAL_SPEECH = "unreal_speech"
|
||||
# --8<-- [end:ProviderName]
|
||||
|
|
|
@ -7,6 +7,6 @@ app_config = Config()
|
|||
# TODO: add test to assert this matches the actual API route
|
||||
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
|
||||
return (
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name}"
|
||||
f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
|
||||
f"/webhooks/{webhook_id}/ingress"
|
||||
)
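# Formatting the (str, Enum) member directly can render as the qualified member name
# (e.g. "ProviderName.TWITTER") on newer Python versions, so .value pins the path
# segment to the plain provider string (e.g. ".../api/integrations/twitter/webhooks/...").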
|
||||
|
|
|
@ -60,11 +60,12 @@ def login(
|
|||
requested_scopes = scopes.split(",") if scopes else []
|
||||
|
||||
# Generate and store a secure random state token along with the scopes
|
||||
state_token = creds_manager.store.store_state_token(
|
||||
state_token, code_challenge = creds_manager.store.store_state_token(
|
||||
user_id, provider, requested_scopes
|
||||
)
|
||||
|
||||
login_url = handler.get_login_url(requested_scopes, state_token)
|
||||
login_url = handler.get_login_url(
|
||||
requested_scopes, state_token, code_challenge=code_challenge
|
||||
)
|
||||
|
||||
return LoginResponse(login_url=login_url, state_token=state_token)
|
||||
|
||||
|
@ -92,19 +93,21 @@ def callback(
|
|||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
# Verify the state token
|
||||
if not creds_manager.store.verify_state_token(user_id, state_token, provider):
|
||||
valid_state = creds_manager.store.verify_state_token(user_id, state_token, provider)
|
||||
|
||||
if not valid_state:
|
||||
logger.warning(f"Invalid or expired state token for user {user_id}")
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired state token")
|
||||
|
||||
try:
|
||||
scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
|
||||
user_id, state_token, provider
|
||||
)
|
||||
scopes = valid_state.scopes
|
||||
logger.debug(f"Retrieved scopes from state token: {scopes}")
|
||||
|
||||
scopes = handler.handle_default_scopes(scopes)
|
||||
|
||||
credentials = handler.exchange_code_for_tokens(code, scopes)
|
||||
credentials = handler.exchange_code_for_tokens(
|
||||
code, scopes, valid_state.code_verifier
|
||||
)
|
||||
|
||||
logger.debug(f"Received credentials with final scopes: {credentials.scopes}")
|
||||
|
||||
# Check if the granted scopes are sufficient for the requested scopes
|
||||
|
@ -317,7 +320,8 @@ async def webhook_ingress_generic(
|
|||
continue
|
||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||
executor.add_execution(
|
||||
node.graph_id,
|
||||
graph_id=node.graph_id,
|
||||
graph_version=node.graph_version,
|
||||
data={f"webhook_{webhook_id}_payload": payload},
|
||||
user_id=webhook.user_id,
|
||||
)
|
||||
|
|
|
@ -56,3 +56,18 @@ class SetGraphActiveVersion(pydantic.BaseModel):
|
|||
|
||||
class UpdatePermissionsRequest(pydantic.BaseModel):
|
||||
permissions: List[APIKeyPermission]
|
||||
|
||||
|
||||
class Pagination(pydantic.BaseModel):
|
||||
total_items: int = pydantic.Field(
|
||||
description="Total number of items.", examples=[42]
|
||||
)
|
||||
total_pages: int = pydantic.Field(
|
||||
description="Total number of pages.", examples=[2]
|
||||
)
|
||||
current_page: int = pydantic.Field(
|
||||
description="Current_page page number.", examples=[1]
|
||||
)
|
||||
page_size: int = pydantic.Field(
|
||||
description="Number of items per page.", examples=[25]
|
||||
)
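# Example serialized form, built from the field examples above:
#   {"total_items": 42, "total_pages": 2, "current_page": 1, "page_size": 25}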
|
||||
|
|
|
@ -2,6 +2,7 @@ import contextlib
|
|||
import logging
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi
|
||||
import fastapi.responses
|
||||
import starlette.middleware.cors
|
||||
|
@ -16,7 +17,9 @@ import backend.data.db
|
|||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.server.routers.v1
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
import backend.server.v2.store.model
|
||||
import backend.server.v2.store.routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
|
@ -117,9 +120,24 @@ class AgentServer(backend.util.service.AppProcess):
|
|||
|
||||
@staticmethod
|
||||
async def test_execute_graph(
|
||||
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: dict[typing.Any, typing.Any],
|
||||
user_id: str,
|
||||
):
|
||||
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
|
||||
return backend.server.routers.v1.execute_graph(
|
||||
graph_id, graph_version, node_input, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph(
|
||||
graph_id, user_id, graph_version
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_graph(
|
||||
|
@ -149,5 +167,71 @@ class AgentServer(backend.util.service.AppProcess):
|
|||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
|
||||
return await backend.server.v2.library.routes.presets.get_presets(
|
||||
user_id=user_id, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_get_preset(preset_id: str, user_id: str):
|
||||
return await backend.server.v2.library.routes.presets.get_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_preset(
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.create_preset(
|
||||
preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_update_preset(
|
||||
preset_id: str,
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.update_preset(
|
||||
preset_id=preset_id, preset=preset, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_preset(preset_id: str, user_id: str):
|
||||
return await backend.server.v2.library.routes.presets.delete_preset(
|
||||
preset_id=preset_id, user_id=user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: dict[typing.Any, typing.Any],
|
||||
user_id: str,
|
||||
):
|
||||
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
preset_id=preset_id,
|
||||
node_input=node_input,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_create_store_listing(
|
||||
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
|
||||
):
|
||||
return await backend.server.v2.store.routes.create_submission(request, user_id)
|
||||
|
||||
@staticmethod
|
||||
async def test_review_store_listing(
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: autogpt_libs.auth.models.User,
|
||||
):
|
||||
return await backend.server.v2.store.routes.review_submission(request, user)
|
||||
|
||||
def set_test_dependency_overrides(self, overrides: dict):
|
||||
app.dependency_overrides.update(overrides)
|
||||
|
|
|
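A quick sketch of calling the updated test helper, assuming AgentServer is imported from the app module this diff modifies; the IDs are placeholders:

async def _demo():
    # Illustrative only: graph_version is now a required argument.
    return await AgentServer.test_execute_graph(
        graph_id="graph-123",
        graph_version=1,
        node_input={},
        user_id="user-123",
    )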
@ -13,6 +13,7 @@ from typing_extensions import Optional, TypedDict
|
|||
import backend.data.block
|
||||
import backend.server.integrations.router
|
||||
import backend.server.routers.analytics
|
||||
import backend.server.v2.library.db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.api_key import (
|
||||
|
@ -180,11 +181,6 @@ async def get_graph(
|
|||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}/versions",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_all_versions(
|
||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
|
@ -200,12 +196,11 @@ async def get_graph_all_versions(
|
|||
async def create_new_graph(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphModel:
|
||||
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
|
||||
return await do_create_graph(create_graph, user_id=user_id)
|
||||
|
||||
|
||||
async def do_create_graph(
|
||||
create_graph: CreateGraph,
|
||||
is_template: bool,
|
||||
# user_id doesn't have to be annotated like on other endpoints,
|
||||
# because create_graph isn't used directly as an endpoint
|
||||
user_id: str,
|
||||
|
@ -217,7 +212,6 @@ async def do_create_graph(
|
|||
graph = await graph_db.get_graph(
|
||||
create_graph.template_id,
|
||||
create_graph.template_version,
|
||||
template=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not graph:
|
||||
|
@ -225,13 +219,18 @@ async def do_create_graph(
|
|||
400, detail=f"Template #{create_graph.template_id} not found"
|
||||
)
|
||||
graph.version = 1
|
||||
|
||||
# Create a library agent for the new graph
|
||||
await backend.server.v2.library.db.create_library_agent(
|
||||
graph.id,
|
||||
graph.version,
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Either graph or template_id must be provided."
|
||||
)
|
||||
|
||||
graph.is_template = is_template
|
||||
graph.is_active = not is_template
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||
|
@ -261,11 +260,6 @@ async def delete_graph(
|
|||
@v1_router.put(
|
||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||
)
|
||||
@v1_router.put(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def update_graph(
|
||||
graph_id: str,
|
||||
graph: graph_db.Graph,
|
||||
|
@ -298,6 +292,11 @@ async def update_graph(
|
|||
|
||||
if new_graph_version.is_active:
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||
user_id, graph.id, graph.version
|
||||
)
|
||||
|
||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||
return integration_creds_manager.get(user_id, credentials_id)
|
||||
|
||||
|
@ -353,6 +352,12 @@ async def set_graph_active_version(
|
|||
version=new_active_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Keep the library agent up to date with the new active version
|
||||
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||
user_id, new_active_graph.id, new_active_graph.version
|
||||
)
|
||||
|
||||
if current_active_graph and current_active_graph.version != new_active_version:
|
||||
# Handle deactivation of the previously active version
|
||||
await on_graph_deactivate(
|
||||
|
@ -368,12 +373,13 @@ async def set_graph_active_version(
|
|||
)
|
||||
def execute_graph(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
try:
|
||||
graph_exec = execution_manager_client().add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
graph_id, node_input, user_id=user_id, graph_version=graph_version
|
||||
)
|
||||
return {"id": graph_exec.graph_exec_id}
|
||||
except Exception as e:
|
||||
|
@ -428,47 +434,6 @@ async def get_graph_run_node_execution_results(
|
|||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Templates ########################
|
||||
########################################################
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates",
|
||||
tags=["graphs", "templates"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_templates(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> Sequence[graph_db.GraphModel]:
|
||||
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/templates/{graph_id}",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_template(
|
||||
graph_id: str, version: int | None = None
|
||||
) -> graph_db.GraphModel:
|
||||
graph = await graph_db.get_graph(graph_id, version, template=True)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
|
||||
return graph
|
||||
|
||||
|
||||
@v1_router.post(
|
||||
path="/templates",
|
||||
tags=["templates", "graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def create_new_template(
|
||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> graph_db.GraphModel:
|
||||
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
|
||||
|
||||
|
||||
########################################################
|
||||
##################### Schedules ########################
|
||||
########################################################
|
||||
|
@ -541,7 +506,7 @@ def get_execution_schedules(
|
|||
|
||||
@v1_router.post(
|
||||
"/api-keys",
|
||||
response_model=list[CreateAPIKeyResponse] | dict[str, str],
|
||||
response_model=CreateAPIKeyResponse,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
@ -583,7 +548,7 @@ async def get_api_keys(
|
|||
|
||||
@v1_router.get(
|
||||
"/api-keys/{key_id}",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
@ -604,7 +569,7 @@ async def get_api_key(
|
|||
|
||||
@v1_router.delete(
|
||||
"/api-keys/{key_id}",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
@ -626,7 +591,7 @@ async def delete_api_key(
|
|||
|
||||
@v1_router.post(
|
||||
"/api-keys/{key_id}/suspend",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
@ -648,7 +613,7 @@ async def suspend_key(
|
|||
|
||||
@v1_router.put(
|
||||
"/api-keys/{key_id}/permissions",
|
||||
response_model=list[APIKeyWithoutHash] | dict[str, str],
|
||||
response_model=APIKeyWithoutHash,
|
||||
tags=["api-keys"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
|
@ -7,6 +7,7 @@ import prisma.types
|
|||
|
||||
import backend.data.graph
|
||||
import backend.data.includes
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.exceptions
|
||||
|
||||
|
@ -14,90 +15,152 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
async def get_library_agents(
|
||||
user_id: str,
|
||||
) -> List[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
|
||||
"""
|
||||
logger.debug(f"Getting library agents for user {user_id}")
|
||||
user_id: str, search_query: str | None = None
|
||||
) -> list[backend.server.v2.library.model.LibraryAgent]:
|
||||
logger.debug(
|
||||
f"Fetching library agents for user_id={user_id} search_query={search_query}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get agents created by user with nodes and links
|
||||
user_created = await prisma.models.AgentGraph.prisma().find_many(
|
||||
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
|
||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
||||
if search_query and len(search_query.strip()) > 100:
|
||||
logger.warning(f"Search query too long: {search_query}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Search query is too long."
|
||||
)
|
||||
|
||||
# Get agents in user's library with nodes and links
|
||||
library_agents = await prisma.models.UserAgent.prisma().find_many(
|
||||
where=prisma.types.UserAgentWhereInput(
|
||||
userId=user_id, isDeleted=False, isArchived=False
|
||||
),
|
||||
include={
|
||||
where_clause = prisma.types.LibraryAgentWhereInput(
|
||||
userId=user_id,
|
||||
isDeleted=False,
|
||||
isArchived=False,
|
||||
)
|
||||
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{
|
||||
"Agent": {
|
||||
"include": {
|
||||
"AgentNodes": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"Webhook": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
}
|
||||
"is": {"name": {"contains": search_query, "mode": "insensitive"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"Agent": {
|
||||
"is": {
|
||||
"description": {"contains": search_query, "mode": "insensitive"}
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include={
|
||||
"Agent": {
|
||||
"include": {
|
||||
"AgentNodes": {"include": {"Input": True, "Output": True}}
|
||||
}
|
||||
}
|
||||
},
|
||||
order=[{"updatedAt": "desc"}],
|
||||
)
|
||||
logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
|
||||
return [
|
||||
backend.server.v2.library.model.LibraryAgent.from_db(agent)
|
||||
for agent in library_agents
|
||||
]
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error fetching library agents: {e}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Unable to fetch library agents."
|
||||
)
|
||||
|
||||
# Convert to Graph models first
|
||||
graphs = []
|
||||
|
||||
# Add user created agents
|
||||
for agent in user_created:
|
||||
try:
|
||||
graphs.append(backend.data.graph.GraphModel.from_db(agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing user created agent {agent.id}: {e}")
|
||||
continue
|
||||
async def create_library_agent(
|
||||
agent_id: str, agent_version: int, user_id: str
|
||||
) -> prisma.models.LibraryAgent:
|
||||
"""
|
||||
Adds an agent to the user's library (LibraryAgent table)
|
||||
"""
|
||||
|
||||
# Add library agents
|
||||
for agent in library_agents:
|
||||
if agent.Agent:
|
||||
try:
|
||||
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing library agent {agent.agentId}: {e}")
|
||||
continue
|
||||
try:
|
||||
|
||||
# Convert Graph models to LibraryAgent models
|
||||
result = []
|
||||
for graph in graphs:
|
||||
result.append(
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.is_active,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
isCreatedByUser=any(a.id == graph.id for a in user_created),
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
)
|
||||
library_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId=user_id,
|
||||
agentId=agent_id,
|
||||
agentVersion=agent_version,
|
||||
isCreatedByUser=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
)
|
||||
|
||||
logger.debug(f"Found {len(result)} library agents")
|
||||
return result
|
||||
|
||||
)
|
||||
return library_agent
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error getting library agents: {str(e)}")
|
||||
logger.error(f"Database error creating agent to library: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch library agents"
|
||||
"Failed to create agent to library"
|
||||
) from e
|
||||
|
||||
|
||||
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
|
||||
async def update_agent_version_in_library(
|
||||
user_id: str, agent_id: str, agent_version: int
|
||||
) -> None:
|
||||
"""
|
||||
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
|
||||
Updates the agent version in the library
|
||||
"""
|
||||
try:
|
||||
await prisma.models.LibraryAgent.prisma().update(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": agent_id,
|
||||
"useGraphIsActiveVersion": True,
|
||||
},
|
||||
data=prisma.types.LibraryAgentUpdateInput(
|
||||
Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
|
||||
connect=prisma.types.AgentGraphWhereUniqueInput(
|
||||
id=agent_id,
|
||||
version=agent_version,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating agent version in library: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to update agent version in library"
|
||||
) from e
|
||||
|
||||
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
auto_update_version: bool = False,
|
||||
is_favorite: bool = False,
|
||||
is_archived: bool = False,
|
||||
is_deleted: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Updates the library agent with the given fields
|
||||
"""
|
||||
try:
|
||||
await prisma.models.LibraryAgent.prisma().update(
|
||||
where={"id": library_agent_id, "userId": user_id},
|
||||
data=prisma.types.LibraryAgentUpdateInput(
|
||||
useGraphIsActiveVersion=auto_update_version,
|
||||
isFavorite=is_favorite,
|
||||
isArchived=is_archived,
|
||||
isDeleted=is_deleted,
|
||||
),
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error updating library agent: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to update library agent"
|
||||
) from e
|
||||
|
||||
|
||||
async def add_store_agent_to_library(
|
||||
store_listing_version_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
|
||||
if they don't already have it
|
||||
"""
|
||||
logger.debug(
|
||||
|
@ -131,7 +194,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
|||
)
|
||||
|
||||
# Check if user already has this agent
|
||||
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
|
||||
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||
where={
|
||||
"userId": user_id,
|
||||
"agentId": agent.id,
|
||||
|
@ -145,9 +208,9 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
|||
)
|
||||
return
|
||||
|
||||
# Create UserAgent entry
|
||||
await prisma.models.UserAgent.prisma().create(
|
||||
data=prisma.types.UserAgentCreateInput(
|
||||
# Create LibraryAgent entry
|
||||
await prisma.models.LibraryAgent.prisma().create(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId=user_id,
|
||||
agentId=agent.id,
|
||||
agentVersion=agent.version,
|
||||
|
@ -163,3 +226,116 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
|||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to add agent to library"
|
||||
) from e
|
||||
|
||||
|
||||
##############################################
|
||||
########### Presets DB Functions #############
|
||||
##############################################
|
||||
|
||||
|
||||
async def get_presets(
|
||||
user_id: str, page: int, page_size: int
|
||||
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
|
||||
|
||||
try:
|
||||
presets = await prisma.models.AgentPreset.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
skip=page * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
|
||||
total_items = await prisma.models.AgentPreset.prisma().count(
|
||||
where={"userId": user_id},
|
||||
)
|
||||
total_pages = (total_items + page_size - 1) // page_size
|
||||
|
||||
presets = [
|
||||
backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
|
||||
for preset in presets
|
||||
]
|
||||
|
||||
return backend.server.v2.library.model.LibraryAgentPresetResponse(
|
||||
presets=presets,
|
||||
pagination=backend.server.model.Pagination(
|
||||
total_items=total_items,
|
||||
total_pages=total_pages,
|
||||
current_page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error getting presets: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch presets"
|
||||
) from e
|
||||
|
||||
|
||||
async def get_preset(
|
||||
user_id: str, preset_id: str
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
|
||||
try:
|
||||
preset = await prisma.models.AgentPreset.prisma().find_unique(
|
||||
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
|
||||
)
|
||||
if not preset:
|
||||
return None
|
||||
return backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error getting preset: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch preset"
|
||||
) from e
|
||||
|
||||
|
||||
async def create_or_update_preset(
|
||||
user_id: str,
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
preset_id: str | None = None,
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||
try:
|
||||
new_preset = await prisma.models.AgentPreset.prisma().upsert(
|
||||
where={
|
||||
"id": preset_id if preset_id else "",
|
||||
},
|
||||
data={
|
||||
"create": {
|
||||
"userId": user_id,
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"agentId": preset.agent_id,
|
||||
"agentVersion": preset.agent_version,
|
||||
"isActive": preset.is_active,
|
||||
"InputPresets": {
|
||||
"create": [
|
||||
{"name": name, "data": json.dumps(data)}
|
||||
for name, data in preset.inputs.items()
|
||||
]
|
||||
},
|
||||
},
|
||||
"update": {
|
||||
"name": preset.name,
|
||||
"description": preset.description,
|
||||
"isActive": preset.is_active,
|
||||
},
|
||||
},
|
||||
)
|
||||
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error creating preset: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to create preset"
|
||||
) from e
|
||||
|
||||
|
||||
async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||
try:
|
||||
await prisma.models.AgentPreset.prisma().update(
|
||||
where={"id": preset_id, "userId": user_id},
|
||||
data={"isDeleted": True},
|
||||
)
|
||||
except prisma.errors.PrismaError as e:
|
||||
logger.error(f"Database error deleting preset: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to delete preset"
|
||||
) from e
|
||||
|
|
|
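A hedged usage sketch of the preset helpers above, as it might appear inside an async test or route; identifiers and inputs are illustrative:

import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model

async def _preset_demo():
    # Illustrative only: IDs and inputs are made up.
    created = await library_db.create_or_update_preset(
        user_id="user-123",
        preset=library_model.CreateLibraryAgentPresetRequest(
            name="Daily digest",
            description="Runs the digest agent with fixed inputs",
            agent_id="agent-123",
            agent_version=1,
            is_active=True,
            inputs={"topic": "ai news"},
        ),
    )
    # Note: get_presets uses `page` as a zero-based multiplier (skip = page * page_size).
    page = await library_db.get_presets(user_id="user-123", page=0, page_size=10)
    return created, page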
@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
|
|||
]
|
||||
|
||||
mock_library_agents = [
|
||||
prisma.models.UserAgent(
|
||||
prisma.models.LibraryAgent(
|
||||
id="ua1",
|
||||
userId="test-user",
|
||||
agentId="agent2",
|
||||
|
@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
|
|||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
isFavorite=False,
|
||||
useGraphIsActiveVersion=True,
|
||||
Agent=prisma.models.AgentGraph(
|
||||
id="agent2",
|
||||
version=1,
|
||||
|
@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
|
|||
return_value=mock_user_created
|
||||
)
|
||||
|
||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
||||
mock_user_agent.return_value.find_many = mocker.AsyncMock(
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||
return_value=mock_library_agents
|
||||
)
|
||||
|
||||
|
@ -76,40 +77,16 @@ async def test_get_library_agents(mocker):
|
|||
result = await db.get_library_agents("test-user")
|
||||
|
||||
# Verify results
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "agent1"
|
||||
assert result[0].name == "Test Agent 1"
|
||||
assert result[0].description == "Test Description 1"
|
||||
assert result[0].isCreatedByUser is True
|
||||
assert result[1].id == "agent2"
|
||||
assert result[1].name == "Test Agent 2"
|
||||
assert result[1].description == "Test Description 2"
|
||||
assert result[1].isCreatedByUser is False
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_agent_graph.return_value.find_many.assert_called_once_with(
|
||||
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
|
||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
mock_user_agent.return_value.find_many.assert_called_once_with(
|
||||
where=prisma.types.UserAgentWhereInput(
|
||||
userId="test-user", isDeleted=False, isArchived=False
|
||||
),
|
||||
include={
|
||||
"Agent": {
|
||||
"include": {
|
||||
"AgentNodes": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"Webhook": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].id == "ua1"
|
||||
assert result[0].name == "Test Agent 2"
|
||||
assert result[0].description == "Test Description 2"
|
||||
assert result[0].is_created_by_user is False
|
||||
assert result[0].is_latest_version is True
|
||||
assert result[0].is_favorite is False
|
||||
assert result[0].agent_id == "agent2"
|
||||
assert result[0].agent_version == 1
|
||||
assert result[0].preset_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -152,26 +129,26 @@ async def test_add_agent_to_library(mocker):
|
|||
return_value=mock_store_listing
|
||||
)
|
||||
|
||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
||||
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_user_agent.return_value.create = mocker.AsyncMock()
|
||||
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||
mock_library_agent.return_value.create = mocker.AsyncMock()
|
||||
|
||||
# Call function
|
||||
await db.add_agent_to_library("version123", "test-user")
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mocks called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
where={"id": "version123"}, include={"Agent": True}
|
||||
)
|
||||
mock_user_agent.return_value.find_first.assert_called_once_with(
|
||||
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||
where={
|
||||
"userId": "test-user",
|
||||
"agentId": "agent1",
|
||||
"agentVersion": 1,
|
||||
}
|
||||
)
|
||||
mock_user_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.UserAgentCreateInput(
|
||||
mock_library_agent.return_value.create.assert_called_once_with(
|
||||
data=prisma.types.LibraryAgentCreateInput(
|
||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||
)
|
||||
)
|
||||
|
@ -189,7 +166,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
|||
|
||||
# Call function and verify exception
|
||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||
await db.add_agent_to_library("version123", "test-user")
|
||||
await db.add_store_agent_to_library("version123", "test-user")
|
||||
|
||||
# Verify mock called correctly
|
||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||
|
|
|
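The tests above follow one pattern: patch the Prisma model accessor and stub its async query methods. A stripped-down template of the same idea, assuming the module under test is imported as db:

import pytest

import backend.server.v2.library.db as db  # module under test, as in the file above


@pytest.mark.asyncio
async def test_returns_empty_library(mocker):
    # Patch the Prisma accessor and stub the async query, mirroring the tests above.
    mock_client = mocker.patch("prisma.models.LibraryAgent.prisma")
    mock_client.return_value.find_many = mocker.AsyncMock(return_value=[])

    assert await db.get_library_agents("test-user") == []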
@ -1,16 +1,112 @@
|
|||
import datetime
|
||||
import json
|
||||
import typing
|
||||
|
||||
import prisma.models
|
||||
import pydantic
|
||||
|
||||
import backend.data.block
|
||||
import backend.data.graph
|
||||
import backend.server.model
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
id: str # Changed from agent_id to match GraphMeta
|
||||
version: int # Changed from agent_version to match GraphMeta
|
||||
is_active: bool # Added to match GraphMeta
|
||||
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
|
||||
preset_id: str | None
|
||||
|
||||
updated_at: datetime.datetime
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
isCreatedByUser: bool
|
||||
# Made input_schema and output_schema match GraphMeta's type
|
||||
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
|
||||
is_favorite: bool
|
||||
is_created_by_user: bool
|
||||
|
||||
is_latest_version: bool
|
||||
|
||||
@staticmethod
|
||||
def from_db(agent: prisma.models.LibraryAgent):
|
||||
if not agent.Agent:
|
||||
raise ValueError("AgentGraph is required")
|
||||
|
||||
graph = backend.data.graph.GraphModel.from_db(agent.Agent)
|
||||
|
||||
agent_updated_at = agent.Agent.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Use the more recent of the graph's and the library agent entry's updated_at timestamps
|
||||
updated_at = (
|
||||
max(agent_updated_at, lib_agent_updated_at)
|
||||
if agent_updated_at
|
||||
else lib_agent_updated_at
|
||||
)
|
||||
|
||||
return LibraryAgent(
|
||||
id=agent.id,
|
||||
agent_id=agent.agentId,
|
||||
agent_version=agent.agentVersion,
|
||||
updated_at=updated_at,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
is_favorite=agent.isFavorite,
|
||||
is_created_by_user=agent.isCreatedByUser,
|
||||
is_latest_version=graph.is_active,
|
||||
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
|
||||
)
|
||||
|
||||
|
||||
class LibraryAgentPreset(pydantic.BaseModel):
|
||||
id: str
|
||||
updated_at: datetime.datetime
|
||||
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
|
||||
name: str
|
||||
description: str
|
||||
|
||||
is_active: bool
|
||||
|
||||
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||
|
||||
@staticmethod
|
||||
def from_db(preset: prisma.models.AgentPreset):
|
||||
input_data = {}
|
||||
|
||||
for data in preset.InputPresets or []:
|
||||
input_data[data.name] = json.loads(data.data)
|
||||
|
||||
return LibraryAgentPreset(
|
||||
id=preset.id,
|
||||
updated_at=preset.updatedAt,
|
||||
agent_id=preset.agentId,
|
||||
agent_version=preset.agentVersion,
|
||||
name=preset.name,
|
||||
description=preset.description,
|
||||
is_active=preset.isActive,
|
||||
inputs=input_data,
|
||||
)
|
||||
|
||||
|
||||
class LibraryAgentPresetResponse(pydantic.BaseModel):
|
||||
presets: list[LibraryAgentPreset]
|
||||
pagination: backend.server.model.Pagination
|
||||
|
||||
|
||||
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||
agent_id: str
|
||||
agent_version: int
|
||||
is_active: bool
|
||||
|
|
|
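A tiny worked example of the timestamp rule LibraryAgent.from_db applies above (values are arbitrary):

import datetime

# The graph was edited after the library row was touched, so its timestamp wins.
agent_updated_at = datetime.datetime(2024, 1, 2)
lib_agent_updated_at = datetime.datetime(2024, 1, 1)
updated_at = (
    max(agent_updated_at, lib_agent_updated_at)
    if agent_updated_at
    else lib_agent_updated_at
)
assert updated_at == agent_updated_at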
@ -1,23 +1,35 @@
|
|||
import datetime
|
||||
|
||||
import prisma.models
|
||||
|
||||
import backend.data.block
|
||||
import backend.server.model
|
||||
import backend.server.v2.library.model
|
||||
|
||||
|
||||
def test_library_agent():
|
||||
agent = backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-123",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="agent-123",
|
||||
agent_version=1,
|
||||
preset_id=None,
|
||||
updated_at=datetime.datetime.now(),
|
||||
name="Test Agent",
|
||||
description="Test description",
|
||||
isCreatedByUser=False,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
is_favorite=False,
|
||||
is_created_by_user=False,
|
||||
is_latest_version=True,
|
||||
)
|
||||
assert agent.id == "test-agent-123"
|
||||
assert agent.version == 1
|
||||
assert agent.is_active is True
|
||||
assert agent.agent_id == "agent-123"
|
||||
assert agent.agent_version == 1
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test description"
|
||||
assert agent.isCreatedByUser is False
|
||||
assert agent.is_favorite is False
|
||||
assert agent.is_created_by_user is False
|
||||
assert agent.is_latest_version is True
|
||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||
|
||||
|
@ -25,19 +37,148 @@ def test_library_agent():
|
|||
def test_library_agent_with_user_created():
|
||||
agent = backend.server.v2.library.model.LibraryAgent(
|
||||
id="user-agent-456",
|
||||
version=2,
|
||||
is_active=True,
|
||||
agent_id="agent-456",
|
||||
agent_version=2,
|
||||
preset_id=None,
|
||||
updated_at=datetime.datetime.now(),
|
||||
name="User Created Agent",
|
||||
description="An agent created by the user",
|
||||
isCreatedByUser=True,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
is_favorite=False,
|
||||
is_created_by_user=True,
|
||||
is_latest_version=True,
|
||||
)
|
||||
assert agent.id == "user-agent-456"
|
||||
assert agent.version == 2
|
||||
assert agent.is_active is True
|
||||
assert agent.agent_id == "agent-456"
|
||||
assert agent.agent_version == 2
|
||||
assert agent.name == "User Created Agent"
|
||||
assert agent.description == "An agent created by the user"
|
||||
assert agent.isCreatedByUser is True
|
||||
assert agent.is_favorite is False
|
||||
assert agent.is_created_by_user is True
|
||||
assert agent.is_latest_version is True
|
||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_library_agent_preset():
|
||||
preset = backend.server.v2.library.model.LibraryAgentPreset(
|
||||
id="preset-123",
|
||||
name="Test Preset",
|
||||
description="Test preset description",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
is_active=True,
|
||||
inputs={
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1",
|
||||
data={"type": "string", "value": "test value"},
|
||||
)
|
||||
},
|
||||
updated_at=datetime.datetime.now(),
|
||||
)
|
||||
assert preset.id == "preset-123"
|
||||
assert preset.name == "Test Preset"
|
||||
assert preset.description == "Test preset description"
|
||||
assert preset.agent_id == "test-agent-123"
|
||||
assert preset.agent_version == 1
|
||||
assert preset.is_active is True
|
||||
assert preset.inputs == {
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1", data={"type": "string", "value": "test value"}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_library_agent_preset_response():
|
||||
preset = backend.server.v2.library.model.LibraryAgentPreset(
|
||||
id="preset-123",
|
||||
name="Test Preset",
|
||||
description="Test preset description",
|
||||
agent_id="test-agent-123",
|
||||
agent_version=1,
|
||||
is_active=True,
|
||||
inputs={
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1",
|
||||
data={"type": "string", "value": "test value"},
|
||||
)
|
||||
},
|
||||
updated_at=datetime.datetime.now(),
|
||||
)
|
||||
|
||||
pagination = backend.server.model.Pagination(
|
||||
total_items=1, total_pages=1, current_page=1, page_size=10
|
||||
)
|
||||
|
||||
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
|
||||
presets=[preset], pagination=pagination
|
||||
)
|
||||
|
||||
assert len(response.presets) == 1
|
||||
assert response.presets[0].id == "preset-123"
|
||||
assert response.pagination.total_items == 1
|
||||
assert response.pagination.total_pages == 1
|
||||
assert response.pagination.current_page == 1
|
||||
assert response.pagination.page_size == 10
|
||||
|
||||
|
||||
def test_create_library_agent_preset_request():
|
||||
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="New Preset",
|
||||
description="New preset description",
|
||||
agent_id="agent-123",
|
||||
agent_version=1,
|
||||
is_active=True,
|
||||
inputs={
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1",
|
||||
data={"type": "string", "value": "test value"},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
assert request.name == "New Preset"
|
||||
assert request.description == "New preset description"
|
||||
assert request.agent_id == "agent-123"
|
||||
assert request.agent_version == 1
|
||||
assert request.is_active is True
|
||||
assert request.inputs == {
|
||||
"input1": backend.data.block.BlockInput(
|
||||
name="input1", data={"type": "string", "value": "test value"}
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_library_agent_from_db():
|
||||
# Create mock DB agent
|
||||
db_agent = prisma.models.AgentPreset(
|
||||
id="test-agent-123",
|
||||
createdAt=datetime.datetime.now(),
|
||||
updatedAt=datetime.datetime.now(),
|
||||
agentId="agent-123",
|
||||
agentVersion=1,
|
||||
name="Test Agent",
|
||||
description="Test agent description",
|
||||
isActive=True,
|
||||
userId="test-user-123",
|
||||
isDeleted=False,
|
||||
InputPresets=[
|
||||
prisma.models.AgentNodeExecutionInputOutput(
|
||||
id="input-123",
|
||||
time=datetime.datetime.now(),
|
||||
name="input1",
|
||||
data='{"type": "string", "value": "test value"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Convert to LibraryAgentPreset
|
||||
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
|
||||
|
||||
assert agent.id == "test-agent-123"
|
||||
assert agent.agent_version == 1
|
||||
assert agent.is_active is True
|
||||
assert agent.name == "Test Agent"
|
||||
assert agent.description == "Test agent description"
|
||||
assert agent.inputs == {"input1": {"type": "string", "value": "test value"}}
|
||||
|
|
|
@ -1,123 +0,0 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
import prisma
|
||||
|
||||
import backend.data.graph
|
||||
import backend.integrations.creds_manager
|
||||
import backend.integrations.webhooks.graph_lifecycle_hooks
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
integration_creds_manager = (
|
||||
backend.integrations.creds_manager.IntegrationCredentialsManager()
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def get_library_agents(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Get all agents in the user's library, including both created and saved agents.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
||||
return agents
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst getting library agents")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to get library agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{store_listing_version_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=201,
|
||||
)
|
||||
async def add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Add an agent from the store to the user's library.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to add
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 201 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error adding the agent to the library
|
||||
"""
|
||||
try:
|
||||
# Get the graph from the store listing
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.Agent:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
agent = store_listing_version.Agent
|
||||
|
||||
if agent.userId == user_id:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400, detail="Cannot add own agent to library"
|
||||
)
|
||||
|
||||
# Create a new graph from the template
|
||||
graph = await backend.data.graph.get_graph(
|
||||
agent.id, agent.version, template=True, user_id=user_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404, detail=f"Agent {agent.id} not found"
|
||||
)
|
||||
|
||||
# Create a deep copy with new IDs
|
||||
graph.version = 1
|
||||
graph.is_template = False
|
||||
graph.is_active = True
|
||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||
|
||||
# Save the new graph
|
||||
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
|
||||
graph = (
|
||||
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
|
||||
graph,
|
||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
||||
)
|
||||
)
|
||||
|
||||
return fastapi.Response(status_code=201)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst adding agent to library")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
|
@ -0,0 +1,9 @@
|
|||
import fastapi
|
||||
|
||||
from .agents import router as agents_router
|
||||
from .presets import router as presets_router
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
router.include_router(presets_router)
|
||||
router.include_router(agents_router)
|
|
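This is plain FastAPI router composition; a minimal mounting sketch for orientation, where router is the combined router above and the /api/library prefix is an assumption rather than the project's actual configuration:

import fastapi

app = fastapi.FastAPI()
# Hypothetical mount point; the real prefix/app wiring lives in the AgentServer
# setup, which this hunk does not show.
app.include_router(router, prefix="/api/library")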
@ -0,0 +1,148 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import autogpt_libs.utils.cache
|
||||
import fastapi
|
||||
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.exceptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/agents",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def get_library_agents(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
]
|
||||
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
||||
"""
|
||||
Get all agents in the user's library, including both created and saved agents.
|
||||
"""
|
||||
try:
|
||||
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
||||
return agents
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst getting library agents: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to get library agents"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/agents/{store_listing_version_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=201,
|
||||
)
|
||||
async def add_agent_to_library(
|
||||
store_listing_version_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Add an agent from the store to the user's library.
|
||||
|
||||
Args:
|
||||
store_listing_version_id (str): ID of the store listing version to add
|
||||
user_id (str): ID of the authenticated user
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 201 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error adding the agent to the library
|
||||
"""
|
||||
try:
|
||||
# Use the database function to add the agent to the library
|
||||
await backend.server.v2.library.db.add_store_agent_to_library(
|
||||
store_listing_version_id, user_id
|
||||
)
|
||||
return fastapi.Response(status_code=201)
|
||||
|
||||
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unexpected exception occurred whilst adding agent to library: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to add agent to library"
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/agents/{library_agent_id}",
|
||||
tags=["library", "private"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
status_code=204,
|
||||
)
|
||||
async def update_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
auto_update_version: bool = False,
|
||||
is_favorite: bool = False,
|
||||
is_archived: bool = False,
|
||||
is_deleted: bool = False,
|
||||
) -> fastapi.Response:
|
||||
"""
|
||||
Update the library agent with the given fields.
|
||||
|
||||
Args:
|
||||
library_agent_id (str): ID of the library agent to update
|
||||
user_id (str): ID of the authenticated user
|
||||
auto_update_version (bool): Whether to auto-update the agent version
|
||||
is_favorite (bool): Whether the agent is marked as favorite
|
||||
is_archived (bool): Whether the agent is archived
|
||||
is_deleted (bool): Whether the agent is deleted
|
||||
|
||||
Returns:
|
||||
fastapi.Response: 204 status code on success
|
||||
|
||||
Raises:
|
||||
HTTPException: If there is an error updating the library agent
|
||||
"""
|
||||
try:
|
||||
# Use the database function to update the library agent
|
||||
await backend.server.v2.library.db.update_library_agent(
|
||||
library_agent_id,
|
||||
user_id,
|
||||
auto_update_version,
|
||||
is_favorite,
|
||||
is_archived,
|
||||
is_deleted,
|
||||
)
|
||||
return fastapi.Response(status_code=204)
|
||||
|
||||
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||
logger.exception(f"Database error occurred whilst updating library agent: {e}")
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to update library agent"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unexpected exception occurred whilst updating library agent: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=500, detail="Failed to update library agent"
|
||||
)
|
|
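Because update_library_agent above declares its boolean flags as plain parameters, FastAPI exposes them as query parameters; a hypothetical client call (base URL, port, and token are assumptions):

import httpx

# Illustrative request only: mark a library agent as favorite and keep auto-updating it.
httpx.put(
    "http://localhost:8006/api/library/agents/lib-agent-123",
    params={"auto_update_version": True, "is_favorite": True},
    headers={"Authorization": "Bearer <access token>"},
)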
@ -0,0 +1,156 @@
|
|||
import logging
|
||||
import typing
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import autogpt_libs.utils.cache
|
||||
import fastapi
|
||||
|
||||
import backend.data.graph
|
||||
import backend.executor
|
||||
import backend.integrations.creds_manager
|
||||
import backend.integrations.webhooks.graph_lifecycle_hooks
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.util.service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = fastapi.APIRouter()
|
||||
integration_creds_manager = (
|
||||
backend.integrations.creds_manager.IntegrationCredentialsManager()
|
||||
)
|
||||
|
||||
|
||||
@autogpt_libs.utils.cache.thread_cached
|
||||
def execution_manager_client() -> backend.executor.ExecutionManager:
|
||||
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
|
||||
|
||||
|
||||
@router.get("/presets")
|
||||
async def get_presets(
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
|
||||
try:
|
||||
presets = await backend.server.v2.library.db.get_presets(
|
||||
user_id, page, page_size
|
||||
)
|
||||
return presets
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst getting presets: {e}")
|
||||
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
|
||||
|
||||
|
||||
@router.get("/presets/{preset_id}")
|
||||
async def get_preset(
|
||||
preset_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||
try:
|
||||
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Preset {preset_id} not found",
|
||||
)
|
||||
return preset
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst getting preset: {e}")
|
||||
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
|
||||
|
||||
|
||||
@router.post("/presets")
|
||||
async def create_preset(
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||
try:
|
||||
return await backend.server.v2.library.db.create_or_update_preset(
|
||||
user_id, preset
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst creating preset: {e}")
|
||||
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
|
||||
|
||||
|
||||
@router.put("/presets/{preset_id}")
|
||||
async def update_preset(
|
||||
preset_id: str,
|
||||
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||
try:
|
||||
return await backend.server.v2.library.db.create_or_update_preset(
|
||||
user_id, preset, preset_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst updating preset: {e}")
|
||||
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
|
||||
|
||||
|
||||
@router.delete("/presets/{preset_id}")
|
||||
async def delete_preset(
|
||||
preset_id: str,
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
):
|
||||
try:
|
||||
await backend.server.v2.library.db.delete_preset(user_id, preset_id)
|
||||
return fastapi.Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.exception(f"Exception occurred whilst deleting preset: {e}")
|
||||
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
|
||||
|
||||
|
||||
@router.post(
|
||||
path="/presets/{preset_id}/execute",
|
||||
tags=["presets"],
|
||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||
)
|
||||
async def execute_preset(
|
||||
graph_id: str,
|
||||
graph_version: int,
|
||||
preset_id: str,
|
||||
node_input: dict[typing.Any, typing.Any],
|
||||
user_id: typing.Annotated[
|
||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||
],
|
||||
) -> dict[str, typing.Any]: # FIXME: add proper return type
|
||||
try:
|
||||
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
|
||||
if not preset:
|
||||
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
|
||||
|
||||
logger.info(f"Preset inputs: {preset.inputs}")
|
||||
|
||||
updated_node_input = node_input.copy()
|
||||
# Merge in preset input values
|
||||
for key, value in preset.inputs.items():
|
||||
if key not in updated_node_input:
|
||||
updated_node_input[key] = value
|
||||
|
||||
execution = execution_manager_client().add_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version,
|
||||
data=updated_node_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
|
||||
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
|
||||
|
||||
return {"id": execution.graph_exec_id}
|
||||
except Exception as e:
|
||||
msg = str(e).encode().decode("unicode_escape")
|
||||
raise fastapi.HTTPException(status_code=400, detail=msg)
|
|
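The preset/input merge in execute_preset above is shallow and caller-wins; a tiny illustration of the rule:

# Caller-supplied inputs take precedence; preset values only fill the gaps.
node_input = {"topic": "override"}
preset_inputs = {"topic": "default", "limit": 5}
merged = node_input.copy()
for key, value in preset_inputs.items():
    if key not in merged:
        merged[key] = value
assert merged == {"topic": "override", "limit": 5}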
@ -1,3 +1,5 @@
|
|||
import datetime
|
||||
|
||||
import autogpt_libs.auth.depends
|
||||
import autogpt_libs.auth.middleware
|
||||
import fastapi
|
||||
|
@ -35,21 +37,29 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
|||
mocked_value = [
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-1",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="test-agent-1",
|
||||
agent_version=1,
|
||||
preset_id="preset-1",
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
is_favorite=False,
|
||||
is_created_by_user=True,
|
||||
is_latest_version=True,
|
||||
name="Test Agent 1",
|
||||
description="Test Description 1",
|
||||
isCreatedByUser=True,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
),
|
||||
backend.server.v2.library.model.LibraryAgent(
|
||||
id="test-agent-2",
|
||||
version=1,
|
||||
is_active=True,
|
||||
agent_id="test-agent-2",
|
||||
agent_version=1,
|
||||
preset_id="preset-2",
|
||||
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
is_favorite=False,
|
||||
is_created_by_user=False,
|
||||
is_latest_version=True,
|
||||
name="Test Agent 2",
|
||||
description="Test Description 2",
|
||||
isCreatedByUser=False,
|
||||
input_schema={"type": "object", "properties": {}},
|
||||
output_schema={"type": "object", "properties": {}},
|
||||
),
|
||||
|
@ -65,10 +75,10 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
|||
for agent in response.json()
|
||||
]
|
||||
assert len(data) == 2
|
||||
assert data[0].id == "test-agent-1"
|
||||
assert data[0].isCreatedByUser is True
|
||||
assert data[1].id == "test-agent-2"
|
||||
assert data[1].isCreatedByUser is False
|
||||
assert data[0].agent_id == "test-agent-1"
|
||||
assert data[0].is_created_by_user is True
|
||||
assert data[1].agent_id == "test-agent-2"
|
||||
assert data[1].is_created_by_user is False
|
||||
mock_db_call.assert_called_once_with("test-user-id")
|
||||
|
||||
|
||||
|
|
|
@ -325,7 +325,10 @@ async def get_store_submissions(
|
|||
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
|
||||
# Query submissions from database
|
||||
submissions = await prisma.models.StoreSubmission.prisma().find_many(
|
||||
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
|
||||
where=where,
|
||||
skip=skip,
|
||||
take=page_size,
|
||||
order=[{"date_submitted": "desc"}],
|
||||
)
|
||||
|
||||
# Get total count for pagination
|
||||
|
@ -405,9 +408,7 @@ async def delete_store_submission(
|
|||
)
|
||||
|
||||
# Delete the submission
|
||||
await prisma.models.StoreListing.prisma().delete(
|
||||
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
|
||||
)
|
||||
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||
|
||||
logger.debug(
|
||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||
|
@ -504,7 +505,15 @@ async def create_store_submission(
|
|||
"subHeading": sub_heading,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
include={"StoreListingVersions": True},
|
||||
)
|
||||
|
||||
slv_id = (
|
||||
listing.StoreListingVersions[0].id
|
||||
if listing.StoreListingVersions is not None
|
||||
and len(listing.StoreListingVersions) > 0
|
||||
else None
|
||||
)
|
||||
|
||||
logger.debug(f"Created store listing for agent {agent_id}")
|
||||
|
@ -521,6 +530,7 @@ async def create_store_submission(
|
|||
status=prisma.enums.SubmissionStatus.PENDING,
|
||||
runs=0,
|
||||
rating=0.0,
|
||||
store_listing_version_id=slv_id,
|
||||
)
|
||||
|
||||
except (
|
||||
|
@ -811,9 +821,7 @@ async def get_agent(
|
|||
|
||||
agent = store_listing_version.Agent
|
||||
|
||||
graph = await backend.data.graph.get_graph(
|
||||
agent.id, agent.version, template=True
|
||||
)
|
||||
graph = await backend.data.graph.get_graph(agent.id, agent.version)
|
||||
|
||||
if not graph:
|
||||
raise fastapi.HTTPException(
|
||||
|
@ -832,3 +840,74 @@ async def get_agent(
|
|||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to fetch agent"
|
||||
) from e
|
||||
|
||||
|
||||
async def review_store_submission(
|
||||
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
|
||||
) -> prisma.models.StoreListingSubmission:
|
||||
"""Review a store listing submission."""
|
||||
try:
|
||||
store_listing_version = (
|
||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||
where={"id": store_listing_version_id},
|
||||
include={"StoreListing": True},
|
||||
)
|
||||
)
|
||||
|
||||
if not store_listing_version or not store_listing_version.StoreListing:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing version {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
status = (
|
||||
prisma.enums.SubmissionStatus.APPROVED
|
||||
if is_approved
|
||||
else prisma.enums.SubmissionStatus.REJECTED
|
||||
)
|
||||
|
||||
create_data = prisma.types.StoreListingSubmissionCreateInput(
|
||||
StoreListingVersion={"connect": {"id": store_listing_version_id}},
|
||||
Status=status,
|
||||
reviewComments=comments,
|
||||
Reviewer={"connect": {"id": reviewer_id}},
|
||||
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
|
||||
update_data = prisma.types.StoreListingSubmissionUpdateInput(
|
||||
Status=status,
|
||||
reviewComments=comments,
|
||||
Reviewer={"connect": {"id": reviewer_id}},
|
||||
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
|
||||
updatedAt=datetime.now(),
|
||||
)
|
||||
|
||||
if is_approved:
|
||||
await prisma.models.StoreListing.prisma().update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={"isApproved": True},
|
||||
)
|
||||
|
||||
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
|
||||
where={"storeListingVersionId": store_listing_version_id},
|
||||
data=prisma.types.StoreListingSubmissionUpsertInput(
|
||||
create=create_data,
|
||||
update=update_data,
|
||||
),
|
||||
)
|
||||
|
||||
if not submission:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Store listing submission {store_listing_version_id} not found",
|
||||
)
|
||||
|
||||
return submission
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error reviewing store submission: {str(e)}")
|
||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||
"Failed to review store submission"
|
||||
) from e
|
||||
|
|
|
@ -115,6 +115,7 @@ class StoreSubmission(pydantic.BaseModel):
|
|||
status: prisma.enums.SubmissionStatus
|
||||
runs: int
|
||||
rating: float
|
||||
store_listing_version_id: str | None = None
|
||||
|
||||
|
||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||
|
@ -151,3 +152,9 @@ class StoreReviewCreate(pydantic.BaseModel):
|
|||
store_listing_version_id: str
|
||||
score: int
|
||||
comments: str | None = None
|
||||
|
||||
|
||||
class ReviewSubmissionRequest(pydantic.BaseModel):
|
||||
store_listing_version_id: str
|
||||
isApproved: bool
|
||||
comments: str
|
||||
|
|
|
@ -642,3 +642,33 @@ async def download_agent_file(
|
|||
return fastapi.responses.FileResponse(
|
||||
tmp_file.name, filename=file_name, media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/submissions/review/{store_listing_version_id}",
|
||||
tags=["store", "private"],
|
||||
)
|
||||
async def review_submission(
|
||||
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||
user: typing.Annotated[
|
||||
autogpt_libs.auth.models.User,
|
||||
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||
],
|
||||
):
|
||||
# Proceed with the review submission logic
|
||||
try:
|
||||
submission = await backend.server.v2.store.db.review_store_submission(
|
||||
store_listing_version_id=request.store_listing_version_id,
|
||||
is_approved=request.isApproved,
|
||||
comments=request.comments,
|
||||
reviewer_id=user.user_id,
|
||||
)
|
||||
return submission
|
||||
except Exception:
|
||||
logger.exception("Exception occurred whilst reviewing store submission")
|
||||
return fastapi.responses.JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"detail": "An error occurred while reviewing the store submission"
|
||||
},
|
||||
)
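A hedged sketch of how an admin client could call this new review endpoint; the base URL, router prefix, and bearer token are placeholders and depend on how the store router is mounted and how auth is configured in a given deployment.

import httpx

BASE_URL = "http://localhost:8006/api/store"  # placeholder, adjust to your deployment
ADMIN_TOKEN = "<admin-jwt>"  # placeholder admin credential

def review_submission(store_listing_version_id: str, approve: bool, comments: str) -> dict:
    # Body mirrors ReviewSubmissionRequest: store_listing_version_id, isApproved, comments.
    response = httpx.post(
        f"{BASE_URL}/submissions/review/{store_listing_version_id}",
        json={
            "store_listing_version_id": store_listing_version_id,
            "isApproved": approve,
            "comments": comments,
        },
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    )
    response.raise_for_status()
    return response.json()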
|
||||
|
|
|
@ -253,7 +253,7 @@ async def block_autogen_agent():
|
|||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
|
|
|
@ -157,7 +157,7 @@ async def reddit_marketing_agent():
|
|||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||
input_data = {"subreddit": "AutoGPT"}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
|
||||
|
|
|
@ -8,12 +8,19 @@ from backend.data.user import get_or_create_user
|
|||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
async def create_test_user() -> User:
|
||||
test_user_data = {
|
||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||
"email": "testuser#example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
async def create_test_user(alt_user: bool = False) -> User:
|
||||
if alt_user:
|
||||
test_user_data = {
|
||||
"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||
"email": "testuser2#example.com",
|
||||
"name": "Test User 2",
|
||||
}
|
||||
else:
|
||||
test_user_data = {
|
||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||
"email": "testuser#example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
user = await get_or_create_user(test_user_data)
|
||||
return user
|
||||
|
||||
|
@ -38,7 +45,7 @@ def create_test_graph() -> graph.Graph:
|
|||
graph.Node(
|
||||
block_id=FillTextTemplateBlock().id,
|
||||
input_default={
|
||||
"format": "{a}, {b}{c}",
|
||||
"format": "{{a}}, {{b}}{{c}}",
|
||||
"values_#_c": "!!!",
|
||||
},
|
||||
),
|
||||
|
@ -79,7 +86,7 @@ async def sample_agent():
|
|||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
|
||||
|
|
|
@ -264,6 +264,10 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||
notion_client_secret: str = Field(
|
||||
default="", description="Notion OAuth client secret"
|
||||
)
|
||||
twitter_client_id: str = Field(default="", description="Twitter/X OAuth client ID")
|
||||
twitter_client_secret: str = Field(
|
||||
default="", description="Twitter/X OAuth client secret"
|
||||
)
|
||||
|
||||
openai_api_key: str = Field(default="", description="OpenAI API key")
|
||||
anthropic_api_key: str = Field(default="", description="Anthropic API key")
|
||||
|
@ -300,7 +304,10 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||
jina_api_key: str = Field(default="", description="Jina API Key")
|
||||
unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")
|
||||
|
||||
fal_key: str = Field(default="", description="FAL API key")
|
||||
fal_api_key: str = Field(default="", description="FAL API key")
|
||||
exa_api_key: str = Field(default="", description="Exa API key")
|
||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||
|
||||
# Add more secret fields as needed
|
||||
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import re
|
||||
|
||||
from jinja2 import BaseLoader
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
|
@ -15,8 +13,5 @@ class TextFormatter:
|
|||
self.env.globals.clear()
|
||||
|
||||
def format_string(self, template_str: str, values=None, **kwargs) -> str:
|
||||
# For python.format compatibility: replace all {...} with {{..}}.
|
||||
# But avoid replacing {{...}} to {{{...}}}.
|
||||
template_str = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", template_str)
|
||||
template = self.env.from_string(template_str)
|
||||
return template.render(values or {}, **kwargs)
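With the single-brace compatibility shim removed here, format strings are plain Jinja templates and must use double braces; a minimal standalone sketch of the resulting behaviour (not the TextFormatter class itself):

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment(loader=BaseLoader())

# Double-brace placeholders render as before.
assert env.from_string("{{a}}, {{b}}{{c}}").render(a="Hello", b="World", c="!!!") == "Hello, World!!!"

# A legacy single-brace template is no longer auto-upgraded; it comes back literally.
assert env.from_string("{a}, {b}{c}").render(a="Hello", b="World", c="!!!") == "{a}, {b}{c}"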
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
Warnings:
|
||||
- You are about to replace the single-brace string input format for the following blocks:
|
||||
- AgentOutputBlock
|
||||
- FillTextTemplateBlock
|
||||
- AITextGeneratorBlock
|
||||
- AIStructuredResponseGeneratorBlock
|
||||
with a double brace format.
|
||||
- This migration can be slow for a large AgentNode table.
|
||||
*/
|
||||
BEGIN;
|
||||
SET LOCAL statement_timeout = '10min';
|
||||
|
||||
WITH to_update AS (
|
||||
SELECT
|
||||
"id",
|
||||
"agentBlockId",
|
||||
"constantInput"::jsonb AS j
|
||||
FROM "AgentNode"
|
||||
WHERE
|
||||
"agentBlockId" IN (
|
||||
'363ae599-353e-4804-937e-b2ee3cef3da4', -- AgentOutputBlock
|
||||
'db7d8f02-2f44-4c55-ab7a-eae0941f0c30', -- FillTextTemplateBlock
|
||||
'1f292d4a-41a4-4977-9684-7c8d560b9f91', -- AITextGeneratorBlock
|
||||
'ed55ac19-356e-4243-a6cb-bc599e9b716f' -- AIStructuredResponseGeneratorBlock
|
||||
)
|
||||
AND (
|
||||
"constantInput"::jsonb->>'format' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
OR "constantInput"::jsonb->>'prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
OR "constantInput"::jsonb->>'sys_prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
)
|
||||
),
|
||||
updated_rows AS (
|
||||
SELECT
|
||||
"id",
|
||||
"agentBlockId",
|
||||
(
|
||||
j
|
||||
-- Update "format" if it has a single-brace placeholder
|
||||
|| CASE WHEN j->>'format' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
THEN jsonb_build_object(
|
||||
'format',
|
||||
regexp_replace(
|
||||
j->>'format',
|
||||
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
|
||||
'{{\1}}',
|
||||
'g'
|
||||
)
|
||||
)
|
||||
ELSE '{}'::jsonb
|
||||
END
|
||||
-- Update "prompt" if it has a single-brace placeholder
|
||||
|| CASE WHEN j->>'prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
THEN jsonb_build_object(
|
||||
'prompt',
|
||||
regexp_replace(
|
||||
j->>'prompt',
|
||||
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
|
||||
'{{\1}}',
|
||||
'g'
|
||||
)
|
||||
)
|
||||
ELSE '{}'::jsonb
|
||||
END
|
||||
-- Update "sys_prompt" if it has a single-brace placeholder
|
||||
|| CASE WHEN j->>'sys_prompt' ~ '(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})'
|
||||
THEN jsonb_build_object(
|
||||
'sys_prompt',
|
||||
regexp_replace(
|
||||
j->>'sys_prompt',
|
||||
'(?<!\{)\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}(?!\})',
|
||||
'{{\1}}',
|
||||
'g'
|
||||
)
|
||||
)
|
||||
ELSE '{}'::jsonb
|
||||
END
|
||||
)::text AS "newConstantInput"
|
||||
FROM to_update
|
||||
)
|
||||
UPDATE "AgentNode" AS an
|
||||
SET "constantInput" = ur."newConstantInput"
|
||||
FROM updated_rows ur
|
||||
WHERE an."id" = ur."id";
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,2 @@
|
|||
-- AlterTable
|
||||
ALTER TABLE "AgentPreset" ADD COLUMN "isDeleted" BOOLEAN NOT NULL DEFAULT false;
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
|
||||
|
||||
*/
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
|
||||
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
|
||||
|
||||
-- DropTable
|
||||
DROP TABLE "UserAgent";
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "LibraryAgent" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"userId" TEXT NOT NULL,
|
||||
"agentId" TEXT NOT NULL,
|
||||
"agentVersion" INTEGER NOT NULL,
|
||||
"agentPresetId" TEXT,
|
||||
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isArchived" BOOLEAN NOT NULL DEFAULT false,
|
||||
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||
|
||||
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
|
@ -0,0 +1,2 @@
|
|||
-- AlterTable
|
||||
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
Warnings:
|
||||
|
||||
- A unique constraint covering the columns `[agentId]` on the table `StoreListing` will be added. If there are existing duplicate values, this will fail.
|
||||
|
||||
*/
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListing_agentId_idx";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListing_isApproved_idx";
|
||||
|
||||
-- DropIndex
|
||||
DROP INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx";
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoreListing_agentId_owningUserId_idx" ON "StoreListing"("agentId", "owningUserId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoreListing_isDeleted_isApproved_idx" ON "StoreListing"("isDeleted", "isApproved");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoreListing_isDeleted_idx" ON "StoreListing"("isDeleted");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "StoreListing_agentId_key" ON "StoreListing"("agentId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isDeleted_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isDeleted");
|
File diff suppressed because it is too large
|
@ -39,8 +39,9 @@ python-dotenv = "^1.0.1"
|
|||
redis = "^5.2.0"
|
||||
sentry-sdk = "2.19.2"
|
||||
strenum = "^0.4.9"
|
||||
supabase = "^2.10.0"
|
||||
supabase = "2.11.0"
|
||||
tenacity = "^9.0.0"
|
||||
tweepy = "^4.14.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.34.0" }
|
||||
websockets = "^13.1"
|
||||
youtube-transcript-api = "^0.6.2"
|
||||
|
|
|
@ -30,7 +30,7 @@ model User {
|
|||
CreditTransaction CreditTransaction[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
UserAgent UserAgent[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
|
||||
Profile Profile[]
|
||||
StoreListing StoreListing[]
|
||||
|
@ -65,7 +65,7 @@ model AgentGraph {
|
|||
AgentGraphExecution AgentGraphExecution[]
|
||||
|
||||
AgentPreset AgentPreset[]
|
||||
UserAgent UserAgent[]
|
||||
LibraryAgent LibraryAgent[]
|
||||
StoreListing StoreListing[]
|
||||
StoreListingVersion StoreListingVersion?
|
||||
|
||||
|
@ -103,15 +103,17 @@ model AgentPreset {
|
|||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||
|
||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||
UserAgents UserAgent[]
|
||||
LibraryAgents LibraryAgent[]
|
||||
AgentExecution AgentGraphExecution[]
|
||||
|
||||
isDeleted Boolean @default(false)
|
||||
|
||||
@@index([userId])
|
||||
}
|
||||
|
||||
// For the library page
|
||||
// It is a user-controlled list of agents that they will see in their library
|
||||
model UserAgent {
|
||||
model LibraryAgent {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
@ -126,6 +128,8 @@ model UserAgent {
|
|||
agentPresetId String?
|
||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||
|
||||
useGraphIsActiveVersion Boolean @default(false)
|
||||
|
||||
isFavorite Boolean @default(false)
|
||||
isCreatedByUser Boolean @default(false)
|
||||
isArchived Boolean @default(false)
|
||||
|
@ -235,7 +239,7 @@ model AgentGraphExecution {
|
|||
|
||||
AgentNodeExecutions AgentNodeExecution[]
|
||||
|
||||
// Link to User model
|
||||
// Link to User model -- Executed by this user
|
||||
userId String
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
|
@ -443,6 +447,8 @@ view Creator {
|
|||
agent_rating Float
|
||||
agent_runs Int
|
||||
is_featured Boolean
|
||||
|
||||
// Index or unique are not applied to views
|
||||
}
|
||||
|
||||
view StoreAgent {
|
||||
|
@ -465,11 +471,7 @@ view StoreAgent {
|
|||
rating Float
|
||||
versions String[]
|
||||
|
||||
@@unique([creator_username, slug])
|
||||
@@index([creator_username])
|
||||
@@index([featured])
|
||||
@@index([categories])
|
||||
@@index([storeListingVersionId])
|
||||
// Index or unique are not applied to views
|
||||
}
|
||||
|
||||
view StoreSubmission {
|
||||
|
@ -487,7 +489,7 @@ view StoreSubmission {
|
|||
agent_id String
|
||||
agent_version Int
|
||||
|
||||
@@index([user_id])
|
||||
// Index or unique are not applied to views
|
||||
}
|
||||
|
||||
model StoreListing {
|
||||
|
@ -510,9 +512,13 @@ model StoreListing {
|
|||
StoreListingVersions StoreListingVersion[]
|
||||
StoreListingSubmission StoreListingSubmission[]
|
||||
|
||||
@@index([isApproved])
|
||||
@@index([agentId])
|
||||
// Unique index on agentId to ensure only one listing per agent, regardless of how many versions the agent has.
|
||||
@@unique([agentId])
|
||||
@@index([agentId, owningUserId])
|
||||
@@index([owningUserId])
|
||||
// Used in the view query
|
||||
@@index([isDeleted, isApproved])
|
||||
@@index([isDeleted])
|
||||
}
|
||||
|
||||
model StoreListingVersion {
|
||||
|
@ -553,7 +559,7 @@ model StoreListingVersion {
|
|||
StoreListingReview StoreListingReview[]
|
||||
|
||||
@@unique([agentId, agentVersion])
|
||||
@@index([agentId, agentVersion, isApproved])
|
||||
@@index([agentId, agentVersion, isDeleted])
|
||||
}
|
||||
|
||||
model StoreListingReview {
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.exceptions
|
||||
import pytest
|
||||
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
|
@ -202,3 +205,92 @@ async def test_clean_graph(server: SpinTestServer):
|
|||
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
|
||||
)
|
||||
assert input_node.input_default["value"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_access_store_listing_graph(server: SpinTestServer):
|
||||
"""
|
||||
Test the access of a store listing graph.
|
||||
"""
|
||||
graph = Graph(
|
||||
id="test_clean_graph",
|
||||
name="Test Clean Graph",
|
||||
description="Test graph cleaning",
|
||||
nodes=[
|
||||
Node(
|
||||
id="input_node",
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={
|
||||
"name": "test_input",
|
||||
"value": "test value",
|
||||
"description": "Test input description",
|
||||
},
|
||||
),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
# Create graph and get model
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=created_graph.id,
|
||||
agent_version=created_graph.version,
|
||||
slug="test-slug",
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
video_url=None,
|
||||
image_urls=[],
|
||||
description="Test description",
|
||||
categories=[],
|
||||
)
|
||||
|
||||
# First we check that the graph cannot be accessed by a different user
|
||||
with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
|
||||
await server.agent_server.test_get_graph(
|
||||
created_graph.id,
|
||||
created_graph.version,
|
||||
"3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "Graph" in str(exc_info.value.detail)
|
||||
|
||||
# Now we create a store listing
|
||||
store_listing = await server.agent_server.test_create_store_listing(
|
||||
store_submission_request, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
if isinstance(store_listing, fastapi.responses.JSONResponse):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
assert slv_id is not None
|
||||
|
||||
admin = autogpt_libs.auth.models.User(
|
||||
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||
role="admin",
|
||||
email="admin@example.com",
|
||||
phone_number="1234567890",
|
||||
)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
isApproved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
admin,
|
||||
)
|
||||
|
||||
# Now we check the graph can be accessed by a user that does not own the graph
|
||||
got_graph = await server.agent_server.test_get_graph(
|
||||
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||
)
|
||||
assert got_graph is not None
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import logging
|
||||
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.server.model import CreateGraph
|
||||
|
@ -31,7 +35,7 @@ async def execute_graph(
|
|||
|
||||
# --- Test adding new executions --- #
|
||||
response = await agent_server.test_execute_graph(
|
||||
test_graph.id, input_data, test_user.id
|
||||
test_graph.id, test_graph.version, input_data, test_user.id
|
||||
)
|
||||
graph_exec_id = response["id"]
|
||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||
|
@ -102,7 +106,7 @@ async def assert_sample_graph_executions(
|
|||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello, World!!!"]}
|
||||
assert exec.input_data == {
|
||||
"format": "{a}, {b}{c}",
|
||||
"format": "{{a}}, {{b}}{{c}}",
|
||||
"values": {"a": "Hello", "b": "World", "c": "!!!"},
|
||||
"values_#_a": "Hello",
|
||||
"values_#_b": "World",
|
||||
|
@ -287,3 +291,255 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
|||
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec_data.output_data == {"result": [9]}
|
||||
logger.info("Completed test_static_input_link_on_graph")
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_execute_preset(server: SpinTestServer):
|
||||
"""
|
||||
Test executing a preset.
|
||||
|
||||
This test ensures that:
|
||||
1. A preset can be successfully executed
|
||||
2. The execution results are correct
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
# Create test graph and user
|
||||
nodes = [
|
||||
graph.Node( # 0
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "dictionary"},
|
||||
),
|
||||
graph.Node( # 1
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "selected_value"},
|
||||
),
|
||||
graph.Node( # 2
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||
),
|
||||
graph.Node( # 3
|
||||
block_id=FindInDictionaryBlock().id,
|
||||
input_default={"key": "", "input": {}},
|
||||
),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="result",
|
||||
sink_name="key",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[2].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence it extracts "key2" from {"key1": "Hello", "key2": "World"}.
|
||||
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[3].output_data == {"output": ["World"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
"""
|
||||
Test executing a preset with clashing input data.
|
||||
"""
|
||||
# Create test graph and user
|
||||
nodes = [
|
||||
graph.Node( # 0
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "dictionary"},
|
||||
),
|
||||
graph.Node( # 1
|
||||
block_id=AgentInputBlock().id,
|
||||
input_default={"name": "selected_value"},
|
||||
),
|
||||
graph.Node( # 2
|
||||
block_id=StoreValueBlock().id,
|
||||
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||
),
|
||||
graph.Node( # 3
|
||||
block_id=FindInDictionaryBlock().id,
|
||||
input_default={"key": "", "input": {}},
|
||||
),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="result",
|
||||
sink_name="key",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[2].id,
|
||||
sink_id=nodes[3].id,
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
]
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
# Create preset with initial values
|
||||
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||
name="Test Preset With Clash",
|
||||
description="Test preset with clashing input values",
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
inputs={
|
||||
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||
"selected_value": "key2",
|
||||
},
|
||||
is_active=True,
|
||||
)
|
||||
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||
|
||||
# Execute preset with overriding values
|
||||
result = await server.agent_server.test_execute_preset(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=test_graph.version,
|
||||
preset_id=created_preset.id,
|
||||
node_input={"selected_value": "key1"},
|
||||
user_id=test_user.id,
|
||||
)
|
||||
|
||||
# Verify execution
|
||||
assert result is not None
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence the node_input override extracts "key1" from {"key1": "Hello", "key2": "World"}.
|
||||
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[3].output_data == {"output": ["Hello"]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_store_listing_graph(server: SpinTestServer):
|
||||
logger.info("Starting test_agent_execution")
|
||||
test_user = await create_test_user()
|
||||
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||
|
||||
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||
agent_id=test_graph.id,
|
||||
agent_version=test_graph.version,
|
||||
slug="test-slug",
|
||||
name="Test name",
|
||||
sub_heading="Test sub heading",
|
||||
video_url=None,
|
||||
image_urls=[],
|
||||
description="Test description",
|
||||
categories=[],
|
||||
)
|
||||
|
||||
store_listing = await server.agent_server.test_create_store_listing(
|
||||
store_submission_request, test_user.id
|
||||
)
|
||||
|
||||
if isinstance(store_listing, fastapi.responses.JSONResponse):
|
||||
assert False, "Failed to create store listing"
|
||||
|
||||
slv_id = (
|
||||
store_listing.store_listing_version_id
|
||||
if store_listing.store_listing_version_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
assert slv_id is not None
|
||||
|
||||
admin = autogpt_libs.auth.models.User(
|
||||
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||
role="admin",
|
||||
email="admin@example.com",
|
||||
phone_number="1234567890",
|
||||
)
|
||||
await server.agent_server.test_review_store_listing(
|
||||
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||
store_listing_version_id=slv_id,
|
||||
isApproved=True,
|
||||
comments="Test comments",
|
||||
),
|
||||
admin,
|
||||
)
|
||||
|
||||
alt_test_user = await create_test_user(alt_user=True)
|
||||
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server,
|
||||
test_graph,
|
||||
alt_test_user,
|
||||
data,
|
||||
4,
|
||||
)
|
||||
|
||||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, alt_test_user, graph_exec_id
|
||||
)
|
||||
logger.info("Completed test_agent_execution")
|
||||
|
|
|
@ -140,10 +140,10 @@ async def main():
|
|||
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
||||
for user in users:
|
||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||
for _ in range(num_agents): # Create 1 UserAgent per user
|
||||
for _ in range(num_agents): # Create 1 LibraryAgent per user
|
||||
graph = random.choice(agent_graphs)
|
||||
preset = random.choice(agent_presets)
|
||||
user_agent = await db.useragent.create(
|
||||
user_agent = await db.libraryagent.create(
|
||||
data={
|
||||
"userId": user.id,
|
||||
"agentId": graph.id,
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
],
|
||||
"dependencies": {
|
||||
"@faker-js/faker": "^9.3.0",
|
||||
"@hookform/resolvers": "^3.9.1",
|
||||
"@next/third-parties": "^15.0.4",
|
||||
"@hookform/resolvers": "^3.10.0",
|
||||
"@next/third-parties": "^15.1.3",
|
||||
"@radix-ui/react-alert-dialog": "^1.1.4",
|
||||
"@radix-ui/react-avatar": "^1.1.1",
|
||||
"@radix-ui/react-checkbox": "^1.1.2",
|
||||
|
@ -59,25 +59,25 @@
|
|||
"dotenv": "^16.4.7",
|
||||
"elliptic": "6.6.1",
|
||||
"embla-carousel-react": "^8.3.0",
|
||||
"framer-motion": "^11.15.0",
|
||||
"framer-motion": "^11.16.0",
|
||||
"geist": "^1.3.1",
|
||||
"launchdarkly-react-client-sdk": "^3.6.0",
|
||||
"lucide-react": "^0.468.0",
|
||||
"lucide-react": "^0.469.0",
|
||||
"moment": "^2.30.1",
|
||||
"next": "^14.2.13",
|
||||
"next-themes": "^0.4.4",
|
||||
"react": "^18",
|
||||
"react-day-picker": "^9.4.4",
|
||||
"react-day-picker": "^9.5.0",
|
||||
"react-dom": "^18",
|
||||
"react-hook-form": "^7.54.0",
|
||||
"react-hook-form": "^7.54.2",
|
||||
"react-icons": "^5.4.0",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-modal": "^3.16.1",
|
||||
"react-markdown": "^9.0.3",
|
||||
"react-modal": "^3.16.3",
|
||||
"react-shepherd": "^6.1.6",
|
||||
"recharts": "^2.14.1",
|
||||
"tailwind-merge": "^2.5.5",
|
||||
"tailwind-merge": "^2.6.0",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"uuid": "^11.0.3",
|
||||
"uuid": "^11.0.4",
|
||||
"zod": "^3.23.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
@ -94,16 +94,16 @@
|
|||
"@storybook/test": "^8.3.5",
|
||||
"@storybook/test-runner": "^0.21.0",
|
||||
"@types/negotiator": "^0.6.3",
|
||||
"@types/node": "^22.9.0",
|
||||
"@types/node": "^22.10.5",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
"@types/react-modal": "^3.16.3",
|
||||
"axe-playwright": "^2.0.3",
|
||||
"chromatic": "^11.12.5",
|
||||
"concurrently": "^9.1.1",
|
||||
"chromatic": "^11.22.0",
|
||||
"concurrently": "^9.1.2",
|
||||
"eslint": "^8",
|
||||
"eslint-config-next": "15.1.3",
|
||||
"eslint-plugin-storybook": "^0.11.0",
|
||||
"eslint-plugin-storybook": "^0.11.2",
|
||||
"msw": "^2.7.0",
|
||||
"msw-storybook-addon": "^2.0.3",
|
||||
"postcss": "^8",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import React from "react";
|
||||
import type { Metadata } from "next";
|
||||
import { Inter } from "next/font/google";
|
||||
import { Inter, Poppins } from "next/font/google";
|
||||
import { Providers } from "@/app/providers";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Navbar } from "@/components/agptui/Navbar";
|
||||
|
@ -10,8 +10,16 @@ import TallyPopupSimple from "@/components/TallyPopup";
|
|||
import { GoogleAnalytics } from "@next/third-parties/google";
|
||||
import { Toaster } from "@/components/ui/toaster";
|
||||
import { IconType } from "@/components/ui/icons";
|
||||
import { GeistSans } from "geist/font/sans";
|
||||
import { GeistMono } from "geist/font/mono";
|
||||
|
||||
const inter = Inter({ subsets: ["latin"] });
|
||||
// Fonts
|
||||
const inter = Inter({ subsets: ["latin"], variable: "--font-inter" });
|
||||
const poppins = Poppins({
|
||||
subsets: ["latin"],
|
||||
weight: ["400", "500", "600", "700"],
|
||||
variable: "--font-poppins",
|
||||
});
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "NextGen AutoGPT",
|
||||
|
@ -24,7 +32,10 @@ export default async function RootLayout({
|
|||
children: React.ReactNode;
|
||||
}>) {
|
||||
return (
|
||||
<html lang="en">
|
||||
<html
|
||||
lang="en"
|
||||
className={`${GeistSans.variable} ${GeistMono.variable} ${poppins.variable} ${inter.variable}`}
|
||||
>
|
||||
<body className={cn("antialiased transition-colors", inter.className)}>
|
||||
<Providers
|
||||
attribute="class"
|
||||
|
|
|
@ -8,6 +8,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
|||
{ text: "Creator Dashboard", href: "/marketplace/dashboard" },
|
||||
{ text: "Agent dashboard", href: "/marketplace/agent-dashboard" },
|
||||
{ text: "Integrations", href: "/marketplace/integrations" },
|
||||
{ text: "API Keys", href: "/marketplace/api_keys" },
|
||||
{ text: "Profile", href: "/marketplace/profile" },
|
||||
{ text: "Settings", href: "/marketplace/settings" },
|
||||
],
|
||||
|
@ -17,7 +18,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
|||
return (
|
||||
<div className="flex min-h-screen w-screen max-w-[1360px] flex-col lg:flex-row">
|
||||
<Sidebar linkGroups={sidebarLinkGroups} />
|
||||
<div className="pl-4">{children}</div>
|
||||
<div className="flex-1 pl-4">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import { APIKeysSection } from "@/components/agptui/composite/APIKeySection";
|
||||
|
||||
const ApiKeysPage = () => {
|
||||
return (
|
||||
<div className="w-full pr-4 pt-24 md:pt-0">
|
||||
<APIKeysSection />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ApiKeysPage;
|
|
@ -253,7 +253,13 @@ export function CustomNode({
|
|||
!isHidden &&
|
||||
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
|
||||
<div key={propKey} data-id={`input-handle-${propKey}`}>
|
||||
{isConnectable ? (
|
||||
{isConnectable &&
|
||||
!(
|
||||
"oneOf" in propSchema &&
|
||||
propSchema.oneOf &&
|
||||
"discriminator" in propSchema &&
|
||||
propSchema.discriminator
|
||||
) ? (
|
||||
<NodeHandle
|
||||
keyName={propKey}
|
||||
isConnected={isConnected}
|
||||
|
|
|
@ -61,7 +61,7 @@ const TallyPopupSimple = () => {
|
|||
<Button
|
||||
variant="default"
|
||||
onClick={resetTutorial}
|
||||
className="font-inter mb-0 h-14 w-28 rounded-2xl bg-[rgba(65,65,64,1)] text-left text-lg font-medium leading-6"
|
||||
className="mb-0 h-14 w-28 rounded-2xl bg-[rgba(65,65,64,1)] text-left font-inter text-lg font-medium leading-6"
|
||||
>
|
||||
Tutorial
|
||||
</Button>
|
||||
|
|
|
@ -2,7 +2,7 @@ import * as React from "react";
|
|||
import Link from "next/link";
|
||||
import { Button } from "./Button";
|
||||
import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet";
|
||||
import { Menu } from "lucide-react";
|
||||
import { KeyIcon, Menu } from "lucide-react";
|
||||
import {
|
||||
IconDashboardLayout,
|
||||
IconIntegrations,
|
||||
|
@ -58,6 +58,15 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
|||
Integrations
|
||||
</div>
|
||||
</Link>
|
||||
<Link
|
||||
href="/marketplace/api_keys"
|
||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||
>
|
||||
<KeyIcon className="h-6 w-6" />
|
||||
<div className="p-ui-medium text-base font-medium leading-normal">
|
||||
API Keys
|
||||
</div>
|
||||
</Link>
|
||||
<Link
|
||||
href="/marketplace/profile"
|
||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||
|
@ -102,6 +111,15 @@ export const Sidebar: React.FC<SidebarProps> = ({ linkGroups }) => {
|
|||
Integrations
|
||||
</div>
|
||||
</Link>
|
||||
<Link
|
||||
href="/marketplace/api_keys"
|
||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||
>
|
||||
<KeyIcon className="h-6 w-6" strokeWidth={1} />
|
||||
<div className="p-ui-medium text-base font-medium leading-normal">
|
||||
API Keys
|
||||
</div>
|
||||
</Link>
|
||||
<Link
|
||||
href="/marketplace/profile"
|
||||
className="inline-flex w-full items-center gap-2.5 rounded-xl px-3 py-3 text-neutral-800 hover:bg-neutral-800 hover:text-white dark:text-neutral-200 dark:hover:bg-neutral-700 dark:hover:text-white"
|
||||
|
|
|
@ -0,0 +1,296 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
import { APIKey, APIKeyPermission } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
import { LuCopy } from "react-icons/lu";
|
||||
import { Loader2, MoreVertical } from "lucide-react";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
|
||||
export function APIKeysSection() {
|
||||
const [apiKeys, setApiKeys] = useState<APIKey[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [isCreateOpen, setIsCreateOpen] = useState(false);
|
||||
const [isKeyDialogOpen, setIsKeyDialogOpen] = useState(false);
|
||||
const [newKeyName, setNewKeyName] = useState("");
|
||||
const [newKeyDescription, setNewKeyDescription] = useState("");
|
||||
const [newApiKey, setNewApiKey] = useState("");
|
||||
const [selectedPermissions, setSelectedPermissions] = useState<
|
||||
APIKeyPermission[]
|
||||
>([]);
|
||||
const { toast } = useToast();
|
||||
const api = useBackendAPI();
|
||||
|
||||
useEffect(() => {
|
||||
loadAPIKeys();
|
||||
}, []);
|
||||
|
||||
const loadAPIKeys = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const keys = await api.listAPIKeys();
|
||||
setApiKeys(keys.filter((key) => key.status === "ACTIVE"));
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateKey = async () => {
|
||||
try {
|
||||
const response = await api.createAPIKey(
|
||||
newKeyName,
|
||||
selectedPermissions,
|
||||
newKeyDescription,
|
||||
);
|
||||
|
||||
setNewApiKey(response.plain_text_key);
|
||||
setIsCreateOpen(false);
|
||||
setIsKeyDialogOpen(true);
|
||||
loadAPIKeys();
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to create AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCopyKey = () => {
|
||||
navigator.clipboard.writeText(newApiKey);
|
||||
toast({
|
||||
title: "Copied",
|
||||
description: "AutoGPT Platform API key copied to clipboard",
|
||||
});
|
||||
};
|
||||
|
||||
const handleRevokeKey = async (keyId: string) => {
|
||||
try {
|
||||
await api.revokeAPIKey(keyId);
|
||||
toast({
|
||||
title: "Success",
|
||||
description: "AutoGPT Platform API key revoked successfully",
|
||||
});
|
||||
loadAPIKeys();
|
||||
} catch (error) {
|
||||
toast({
|
||||
title: "Error",
|
||||
description: "Failed to revoke AutoGPT Platform API key",
|
||||
variant: "destructive",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>AutoGPT Platform API Keys</CardTitle>
|
||||
<CardDescription>
|
||||
Manage your AutoGPT Platform API keys for programmatic access
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="mb-4 flex justify-end">
|
||||
<Dialog open={isCreateOpen} onOpenChange={setIsCreateOpen}>
|
||||
<DialogTrigger asChild>
|
||||
<Button>Create Key</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create New API Key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create a new AutoGPT Platform API key
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="grid gap-4 py-4">
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="name">Name</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={newKeyName}
|
||||
onChange={(e) => setNewKeyName(e.target.value)}
|
||||
placeholder="My AutoGPT Platform API Key"
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label htmlFor="description">Description (Optional)</Label>
|
||||
<Input
|
||||
id="description"
|
||||
value={newKeyDescription}
|
||||
onChange={(e) => setNewKeyDescription(e.target.value)}
|
||||
placeholder="Used for..."
|
||||
/>
|
||||
</div>
|
||||
<div className="grid gap-2">
|
||||
<Label>Permissions</Label>
|
||||
{Object.values(APIKeyPermission).map((permission) => (
|
||||
<div
|
||||
className="flex items-center space-x-2"
|
||||
key={permission}
|
||||
>
|
||||
<Checkbox
|
||||
id={permission}
|
||||
checked={selectedPermissions.includes(permission)}
|
||||
onCheckedChange={(checked) => {
|
||||
setSelectedPermissions(
|
||||
checked
|
||||
? [...selectedPermissions, permission]
|
||||
: selectedPermissions.filter(
|
||||
(p) => p !== permission,
|
||||
),
|
||||
);
|
||||
}}
|
||||
/>
|
||||
<Label htmlFor={permission}>{permission}</Label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => setIsCreateOpen(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleCreateKey}>Create</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={isKeyDialogOpen} onOpenChange={setIsKeyDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>AutoGPT Platform API Key Created</DialogTitle>
|
||||
<DialogDescription>
|
||||
Please copy your AutoGPT API key now. You won't be able
|
||||
to see it again!
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 rounded-md bg-secondary p-2 text-sm">
|
||||
{newApiKey}
|
||||
</code>
|
||||
<Button size="icon" variant="outline" onClick={handleCopyKey}>
|
||||
<LuCopy className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button onClick={() => setIsKeyDialogOpen(false)}>Close</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center p-4">
|
||||
<Loader2 className="h-6 w-6 animate-spin" />
|
||||
</div>
|
||||
) : (
|
||||
apiKeys.length > 0 && (
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Created</TableHead>
|
||||
<TableHead>Last Used</TableHead>
|
||||
<TableHead></TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{apiKeys.map((key) => (
|
||||
<TableRow key={key.id}>
|
||||
<TableCell>{key.name}</TableCell>
|
||||
<TableCell>
|
||||
<div className="rounded-md border p-1 px-2 text-xs">
|
||||
{`${key.prefix}******************${key.postfix}`}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Badge
|
||||
variant={
|
||||
key.status === "ACTIVE" ? "default" : "destructive"
|
||||
}
|
||||
className={
|
||||
key.status === "ACTIVE"
|
||||
? "border-green-600 bg-green-100 text-green-800"
|
||||
: "border-red-600 bg-red-100 text-red-800"
|
||||
}
|
||||
>
|
||||
{key.status}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{new Date(key.created_at).toLocaleDateString()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{key.last_used_at
|
||||
? new Date(key.last_used_at).toLocaleDateString()
|
||||
: "Never"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="sm">
|
||||
<MoreVertical className="h-4 w-4" />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
<DropdownMenuItem
|
||||
className="text-destructive"
|
||||
onClick={() => handleRevokeKey(key.id)}
|
||||
>
|
||||
Revoke
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
)
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
|
@ -6,9 +6,11 @@ import {
|
|||
Carousel,
|
||||
CarouselContent,
|
||||
CarouselItem,
|
||||
CarouselPrevious,
|
||||
CarouselNext,
|
||||
CarouselIndicator,
|
||||
} from "@/components/ui/carousel";
|
||||
import { useCallback, useState } from "react";
|
||||
import { IconLeftArrow, IconRightArrow } from "@/components/ui/icons";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
const BACKGROUND_COLORS = [
|
||||
|
@ -63,27 +65,24 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
|
|||
|
||||
return (
|
||||
<div className="flex w-full flex-col items-center justify-center">
|
||||
<div className="w-full">
|
||||
<h2 className="font-poppins mb-8 text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
|
||||
<div className="w-[99vw]">
|
||||
<h2 className="font-poppins mx-auto mb-8 max-w-[1360px] px-4 text-2xl font-semibold leading-7 text-neutral-800 dark:text-neutral-200">
|
||||
Featured agents
|
||||
</h2>
|
||||
|
||||
<div>
|
||||
<div className="w-[99vw] pb-[60px]">
|
||||
<Carousel
|
||||
className="mx-auto pb-10"
|
||||
opts={{
|
||||
loop: true,
|
||||
startIndex: currentSlide,
|
||||
duration: 500,
|
||||
align: "start",
|
||||
align: "center",
|
||||
containScroll: "trimSnaps",
|
||||
}}
|
||||
className="w-full overflow-x-hidden"
|
||||
>
|
||||
<CarouselContent className="transition-transform duration-500">
|
||||
<CarouselContent className="ml-[calc(50vw-690px)]">
|
||||
{featuredAgents.map((agent, index) => (
|
||||
<CarouselItem
|
||||
key={index}
|
||||
className="max-w-[460px] flex-[0_0_auto] pr-8"
|
||||
className="max-w-[460px] flex-[0_0_auto]"
|
||||
>
|
||||
<FeaturedStoreCard
|
||||
agentName={agent.agent_name}
|
||||
|
@ -99,37 +98,13 @@ export const FeaturedSection: React.FC<FeaturedSectionProps> = ({
|
|||
</CarouselItem>
|
||||
))}
|
||||
</CarouselContent>
|
||||
<div className="relative mx-auto w-full max-w-[1360px] pl-4">
|
||||
<CarouselIndicator />
|
||||
<CarouselPrevious afterClick={handlePrevSlide} />
|
||||
<CarouselNext afterClick={handleNextSlide} />
|
||||
</div>
|
||||
</Carousel>
|
||||
</div>
|
||||
|
||||
<div className="mt-8 flex w-full items-center justify-between">
|
||||
<div className="flex h-3 items-center gap-2">
|
||||
{featuredAgents.map((_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`${
|
||||
currentSlide === index
|
||||
? "h-3 w-[52px] rounded-[39px] bg-neutral-800 transition-all duration-500 dark:bg-neutral-200"
|
||||
: "h-3 w-3 rounded-full bg-neutral-300 transition-all duration-500 dark:bg-neutral-600"
|
||||
}`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<div className="mb-[60px] flex items-center gap-3">
|
||||
<button
|
||||
onClick={handlePrevSlide}
|
||||
className="mb:h-12 mb:w-12 flex h-10 w-10 items-center justify-center rounded-full border border-neutral-400 bg-white dark:border-neutral-600 dark:bg-neutral-800"
|
||||
>
|
||||
<IconLeftArrow className="h-8 w-8 text-neutral-800 dark:text-neutral-200" />
|
||||
</button>
|
||||
<button
|
||||
onClick={handleNextSlide}
|
||||
className="mb:h-12 mb:w-12 flex h-10 w-10 items-center justify-center rounded-full border border-neutral-900 bg-white dark:border-neutral-600 dark:bg-neutral-800"
|
||||
>
|
||||
<IconRightArrow className="h-8 w-8 text-neutral-800 dark:text-neutral-200" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
@ -7,8 +7,15 @@ import SchemaTooltip from "@/components/SchemaTooltip";
|
|||
import useCredentials from "@/hooks/useCredentials";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { NotionLogoIcon } from "@radix-ui/react-icons";
|
||||
import { FaDiscord, FaGithub, FaGoogle, FaMedium, FaKey } from "react-icons/fa";
|
||||
import { FC, useState } from "react";
|
||||
import {
|
||||
FaDiscord,
|
||||
FaGithub,
|
||||
FaTwitter,
|
||||
FaGoogle,
|
||||
FaMedium,
|
||||
FaKey,
|
||||
} from "react-icons/fa";
|
||||
import { FC, useMemo, useState } from "react";
|
||||
import {
|
||||
CredentialsMetaInput,
|
||||
CredentialsProviderName,
|
||||
|
@ -53,6 +60,7 @@ export const providerIcons: Record<
|
|||
google: FaGoogle,
|
||||
groq: fallbackIcon,
|
||||
notion: NotionLogoIcon,
|
||||
nvidia: fallbackIcon,
|
||||
discord: FaDiscord,
|
||||
d_id: fallbackIcon,
|
||||
google_maps: FaGoogle,
|
||||
|
@ -68,6 +76,7 @@ export const providerIcons: Record<
|
|||
replicate: fallbackIcon,
|
||||
fal: fallbackIcon,
|
||||
revid: fallbackIcon,
|
||||
twitter: FaTwitter,
|
||||
unreal_speech: fallbackIcon,
|
||||
exa: fallbackIcon,
|
||||
hubspot: fallbackIcon,
|
||||
|
|
|
@ -28,6 +28,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
|||
jina: "Jina",
|
||||
medium: "Medium",
|
||||
notion: "Notion",
|
||||
nvidia: "Nvidia",
|
||||
ollama: "Ollama",
|
||||
openai: "OpenAI",
|
||||
openweathermap: "OpenWeatherMap",
|
||||
|
@ -37,6 +38,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
|||
replicate: "Replicate",
|
||||
fal: "FAL",
|
||||
revid: "Rev.ID",
|
||||
twitter: "Twitter",
|
||||
unreal_speech: "Unreal Speech",
|
||||
exa: "Exa",
|
||||
hubspot: "Hubspot",
|
||||
|
|
|
@ -17,6 +17,7 @@ import {
|
|||
BlockIOStringSubSchema,
|
||||
BlockIONumberSubSchema,
|
||||
BlockIOBooleanSubSchema,
|
||||
BlockIOSimpleTypeSubSchema,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import React, { FC, useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { Button } from "./ui/button";
|
||||
|
@ -40,6 +41,7 @@ import { LocalValuedInput } from "./ui/input";
|
|||
import NodeHandle from "./NodeHandle";
|
||||
import { ConnectionData } from "./CustomNode";
|
||||
import { CredentialsInput } from "./integrations/credentials-input";
|
||||
import { MultiSelect } from "./ui/multiselect-input";
|
||||
|
||||
type NodeObjectInputTreeProps = {
|
||||
nodeId: string;
|
||||
|
@@ -101,6 +103,92 @@ const NodeObjectInputTree: FC<NodeObjectInputTreeProps> = ({

export default NodeObjectInputTree;

const NodeImageInput: FC<{
selfKey: string;
schema: BlockIOStringSubSchema;
value?: string;
error?: string;
handleInputChange: NodeObjectInputTreeProps["handleInputChange"];
className?: string;
displayName: string;
}> = ({
selfKey,
schema,
value = "",
error,
handleInputChange,
className,
displayName,
}) => {
const handleFileChange = useCallback(
async (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (!file) return;

// Validate file type
if (!file.type.startsWith("image/")) {
console.error("Please upload an image file");
return;
}

// Convert to base64
const reader = new FileReader();
reader.onload = (e) => {
const base64String = (e.target?.result as string).split(",")[1];
handleInputChange(selfKey, base64String);
};
reader.readAsDataURL(file);
},
[selfKey, handleInputChange],
);

return (
<div className={cn("flex flex-col gap-2", className)}>
<div className="nodrag flex flex-col gap-2">
<div className="flex items-center gap-2">
<Button
variant="outline"
onClick={() =>
document.getElementById(`${selfKey}-upload`)?.click()
}
className="w-full"
>
{value ? "Change Image" : `Upload ${displayName}`}
</Button>
{value && (
<Button
variant="ghost"
className="text-red-500 hover:text-red-700"
onClick={() => handleInputChange(selfKey, "")}
>
<Cross2Icon className="h-4 w-4" />
</Button>
)}
</div>

<input
id={`${selfKey}-upload`}
type="file"
accept="image/*"
onChange={handleFileChange}
className="hidden"
/>

{value && (
<div className="relative mt-2 rounded-md border border-gray-300 p-2 dark:border-gray-600">
<img
src={`data:image/jpeg;base64,${value}`}
alt="Preview"
className="max-h-32 w-full rounded-md object-contain"
/>
</div>
)}
</div>
{error && <span className="error-message">{error}</span>}
</div>
);
};

const NodeDateTimeInput: FC<{
selfKey: string;
schema: BlockIOStringSubSchema;
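
A minimal, self-contained sketch (not part of the diff) of the base64 handling that NodeImageInput above relies on: FileReader.readAsDataURL yields a "data:<mime>;base64,<payload>" string, the component stores only the payload, and the preview re-adds a data-URL prefix (image/jpeg is hard-coded there regardless of the uploaded type). The example value is made up.

const toBase64Payload = (dataUrl: string): string => dataUrl.split(",")[1];
const toPreviewSrc = (payload: string): string => `data:image/jpeg;base64,${payload}`;

const example = "data:image/png;base64,iVBORw0KGgo=";
console.log(toBase64Payload(example)); // "iVBORw0KGgo="
console.log(toPreviewSrc(toBase64Payload(example))); // "data:image/jpeg;base64,iVBORw0KGgo="
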
@@ -225,6 +313,8 @@ export const NodeGenericInputField: FC<{
);
}

console.log("propSchema", propSchema);

if ("properties" in propSchema) {
// Render a multi-select for all-boolean sub-schemas with more than 3 properties
if (
@@ -290,12 +380,53 @@ export const NodeGenericInputField: FC<{
}

if ("anyOf" in propSchema) {
// Optional oneOf
if (
"oneOf" in propSchema.anyOf[0] &&
propSchema.anyOf[0].oneOf &&
"discriminator" in propSchema.anyOf[0] &&
propSchema.anyOf[0].discriminator
) {
return (
<NodeOneOfDiscriminatorField
nodeId={nodeId}
propKey={propKey}
propSchema={propSchema.anyOf[0]}
currentValue={currentValue}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
className={className}
displayName={displayName}
/>
);
}

// optional items
const types = propSchema.anyOf.map((s) =>
"type" in s ? s.type : undefined,
);
if (types.includes("string") && types.includes("null")) {
// optional string
// optional string and datetime

if (
"format" in propSchema.anyOf[0] &&
propSchema.anyOf[0].format === "date-time"
) {
return (
<NodeDateTimeInput
selfKey={propKey}
schema={propSchema.anyOf[0]}
value={currentValue}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={handleInputChange}
/>
);
}

return (
<NodeStringInput
selfKey={propKey}
@@ -356,6 +487,42 @@ export const NodeGenericInputField: FC<{
/>
);
} else if (types.includes("object") && types.includes("null")) {
// rendering optional multiselect
if (
Object.values(
(propSchema.anyOf[0] as BlockIOObjectSubSchema).properties,
).every(
(subSchema) => "type" in subSchema && subSchema.type == "boolean",
) &&
Object.keys((propSchema.anyOf[0] as BlockIOObjectSubSchema).properties)
.length >= 1
) {
const options = Object.keys(
(propSchema.anyOf[0] as BlockIOObjectSubSchema).properties,
);
const selectedKeys = Object.entries(currentValue || {})
.filter(([_, v]) => v)
.map(([k, _]) => k);
return (
<NodeMultiSelectInput
selfKey={propKey}
schema={propSchema.anyOf[0] as BlockIOObjectSubSchema}
selection={selectedKeys}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={(key, selection) => {
handleInputChange(
key,
Object.fromEntries(
options.map((option) => [option, selection.includes(option)]),
),
);
}}
/>
);
}

return (
<NodeKeyValueInput
nodeId={nodeId}
@@ -418,6 +585,19 @@ export const NodeGenericInputField: FC<{

switch (propSchema.type) {
case "string":
if ("image_upload" in propSchema && propSchema.image_upload === true) {
return (
<NodeImageInput
selfKey={propKey}
schema={propSchema}
value={currentValue}
error={errors[propKey]}
className={className}
displayName={displayName}
handleInputChange={handleInputChange}
/>
);
}
if ("format" in propSchema && propSchema.format === "date-time") {
return (
<NodeDateTimeInput
@@ -523,7 +703,7 @@ const NodeOneOfDiscriminatorField: FC<{
propSchema: any;
currentValue?: any;
errors: { [key: string]: string | undefined };
connections: any;
connections: ConnectionData;
handleInputChange: (key: string, value: any) => void;
handleInputClick: (key: string) => void;
className?: string;
@@ -538,7 +718,6 @@ const NodeOneOfDiscriminatorField: FC<{
handleInputChange,
handleInputClick,
className,
displayName,
}) => {
const discriminator = propSchema.discriminator;

@@ -554,7 +733,7 @@ const NodeOneOfDiscriminatorField: FC<{

return {
value: variantDiscValue,
schema: variant,
schema: variant as BlockIOSubSchema,
};
})
.filter((v: any) => v.value != null);
@@ -585,8 +764,24 @@ const NodeOneOfDiscriminatorField: FC<{
(opt: any) => opt.value === chosenType,
)?.schema;

function getEntryKey(key: string): string {
// use someKey (not childKey) when building the handle key
return `${propKey}_#_${key}`;
}

function isConnected(key: string): boolean {
return connections.some(
(c) => c.targetHandle === getEntryKey(key) && c.target === nodeId,
);
}

return (
<div className={cn("flex flex-col space-y-2", className)}>
<div
className={cn(
"flex min-w-[400px] max-w-[95%] flex-col space-y-4",
className,
)}
>
<Select value={chosenType || ""} onValueChange={handleVariantChange}>
<SelectTrigger className="w-full">
<SelectValue placeholder="Select a type..." />
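
A standalone sketch (not part of the diff) of how the handle keys built by getEntryKey above get matched against connections; the propKey "credentials", the child key "api_key", and the node id "node-1" are made-up examples.

type Connection = { target: string; targetHandle: string };

const propKey = "credentials";
const getEntryKey = (key: string): string => `${propKey}_#_${key}`;

const connections: Connection[] = [
  { target: "node-1", targetHandle: "credentials_#_api_key" },
];

const isConnected = (nodeId: string, key: string): boolean =>
  connections.some((c) => c.targetHandle === getEntryKey(key) && c.target === nodeId);

console.log(isConnected("node-1", "api_key")); // true
console.log(isConnected("node-1", "access_token")); // false
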
@@ -607,32 +802,36 @@ const NodeOneOfDiscriminatorField: FC<{
if (someKey === "discriminator") {
return null;
}
const childKey = propKey ? `${propKey}.${someKey}` : someKey;
const childKey = propKey ? `${propKey}.${someKey}` : someKey; // for history redo/undo purposes
return (
<div
key={childKey}
className="flex w-full flex-row justify-between space-y-2"
className="mb-4 flex w-full flex-col justify-between space-y-2"
>
<span className="mr-2 mt-3 dark:text-gray-300">
{(childSchema as BlockIOSubSchema).title ||
beautifyString(someKey)}
</span>
<NodeGenericInputField
nodeId={nodeId}
key={propKey}
propKey={childKey}
propSchema={childSchema as BlockIOSubSchema}
currentValue={
currentValue ? currentValue[someKey] : undefined
}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
displayName={
chosenVariantSchema.title || beautifyString(someKey)
}
<NodeHandle
keyName={getEntryKey(someKey)}
schema={childSchema as BlockIOSubSchema}
isConnected={isConnected(getEntryKey(someKey))}
isRequired={false}
side="left"
/>

{!isConnected(someKey) && (
<NodeGenericInputField
nodeId={nodeId}
key={propKey}
propKey={childKey}
propSchema={childSchema as BlockIOSubSchema}
currentValue={
currentValue ? currentValue[someKey] : undefined
}
errors={errors}
connections={connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
displayName={beautifyString(someKey)}
/>
)}
</div>
);
},
@@ -827,6 +1026,13 @@ const NodeKeyValueInput: FC<{
);
};

// Check whether the schema is a string sub-schema
function isStringSubSchema(
schema: BlockIOSimpleTypeSubSchema,
): schema is BlockIOStringSubSchema {
return "type" in schema && schema.type === "string";
}

const NodeArrayInput: FC<{
nodeId: string;
selfKey: string;

@@ -1,10 +1,11 @@
// This file has been updated for the Store's "Featured Agent Section". If you want to add a carousel, keep these components in mind: CarouselIndicator, CarouselPrevious, and CarouselNext.
"use client";

import * as React from "react";
import useEmblaCarousel, {
type UseEmblaCarouselType,
} from "embla-carousel-react";
import { ArrowLeft, ArrowRight } from "lucide-react";
import { ChevronLeft, ChevronRight } from "lucide-react";

import { cn } from "@/lib/utils";
import { Button } from "@/components/ui/button";
@@ -196,67 +197,137 @@ CarouselItem.displayName = "CarouselItem";

const CarouselPrevious = React.forwardRef<
HTMLButtonElement,
React.ComponentProps<typeof Button>
>(({ className, variant = "outline", size = "icon", ...props }, ref) => {
const { orientation, scrollPrev, canScrollPrev } = useCarousel();
React.ComponentProps<typeof Button> & { afterClick?: () => void }
>(
(
{ className, afterClick, variant = "outline", size = "icon", ...props },
ref,
) => {
const { orientation, scrollPrev, canScrollPrev } = useCarousel();

return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-8 w-8 rounded-full",
orientation === "horizontal"
? "-left-12 top-1/2 -translate-y-1/2"
: "-top-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollPrev}
onClick={scrollPrev}
{...props}
>
<ArrowLeft className="h-4 w-4" />
<span className="sr-only">Previous slide</span>
</Button>
);
});
return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-[52px] w-[52px] rounded-full",
orientation === "horizontal"
? "-bottom-20 right-24 -translate-y-1/2"
: "-top-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollPrev}
onClick={() => {
scrollPrev();
if (afterClick) {
afterClick();
}
}}
{...props}
>
<ChevronLeft className="h-8 w-8" strokeWidth={1.25} />
<span className="sr-only">Previous slide</span>
</Button>
);
},
);
CarouselPrevious.displayName = "CarouselPrevious";

const CarouselNext = React.forwardRef<
HTMLButtonElement,
React.ComponentProps<typeof Button>
>(({ className, variant = "outline", size = "icon", ...props }, ref) => {
const { orientation, scrollNext, canScrollNext } = useCarousel();
React.ComponentProps<typeof Button> & { afterClick?: () => void }
>(
(
{ className, afterClick, variant = "outline", size = "icon", ...props },
ref,
) => {
const { orientation, scrollNext, canScrollNext } = useCarousel();

const handleClick = () => {
scrollNext();
if (afterClick) {
afterClick();
}
};

return (
<Button
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-[52px] w-[52px] rounded-full",
orientation === "horizontal"
? "-bottom-20 right-4 -translate-y-1/2"
: "-bottom-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollNext}
onClick={handleClick}
{...props}
>
<ChevronRight className="h-8 w-8" strokeWidth={1.25} />
<span className="sr-only">Next slide</span>
</Button>
);
},
);
CarouselNext.displayName = "CarouselNext";

const CarouselIndicator = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => {
const { api } = useCarousel();
const [selectedIndex, setSelectedIndex] = React.useState(0);
const [scrollSnaps, setScrollSnaps] = React.useState<number[]>([]);

const scrollTo = React.useCallback(
(index: number) => {
api?.scrollTo(index);
},
[api],
);

React.useEffect(() => {
if (!api) return;

setScrollSnaps(api.scrollSnapList());
api.on("select", () => {
setSelectedIndex(api.selectedScrollSnap());
});
}, [api]);

return (
<Button
<div
ref={ref}
variant={variant}
size={size}
className={cn(
"absolute h-8 w-8 rounded-full",
orientation === "horizontal"
? "-right-12 top-1/2 -translate-y-1/2"
: "-bottom-12 left-1/2 -translate-x-1/2 rotate-90",
className,
)}
disabled={!canScrollNext}
onClick={scrollNext}
className={cn("relative top-10 flex h-3 items-center gap-2", className)}
{...props}
>
<ArrowRight className="h-4 w-4" />
<span className="sr-only">Next slide</span>
</Button>
{scrollSnaps.map((_, index) => (
<div
key={index}
onClick={() => scrollTo(index)}
className={cn(
selectedIndex === index
? "h-3 w-[52px] rounded-[39px] bg-neutral-800 transition-all duration-500 dark:bg-neutral-200"
: "h-3 w-3 rounded-full bg-neutral-300 transition-all duration-500 dark:bg-neutral-600",
"cursor-pointer",
)}
/>
))}
</div>
);
});
CarouselNext.displayName = "CarouselNext";
CarouselIndicator.displayName = "CarouselIndicator";

export {
type CarouselApi,
Carousel,
CarouselContent,
CarouselItem,
CarouselIndicator,
CarouselPrevious,
CarouselNext,
};

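A hypothetical usage sketch (not from the diff) of the carousel exports above; the "@/components/ui/carousel" import path and the `items` shape are assumptions for illustration, while the component names and the afterClick prop come from the diff.

import {
  Carousel,
  CarouselContent,
  CarouselItem,
  CarouselIndicator,
  CarouselPrevious,
  CarouselNext,
} from "@/components/ui/carousel";

export function FeaturedAgentsCarousel({
  items,
}: {
  items: { id: string; title: string }[];
}) {
  return (
    <Carousel>
      <CarouselContent>
        {items.map((item) => (
          <CarouselItem key={item.id}>{item.title}</CarouselItem>
        ))}
      </CarouselContent>
      <CarouselIndicator />
      <CarouselPrevious />
      {/* afterClick fires after the slide change, e.g. to pause autoplay or log the interaction */}
      <CarouselNext afterClick={() => console.log("advanced to next slide")} />
    </Carousel>
  );
}
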
@@ -17,6 +17,9 @@ const isValidVideoUrl = (url: string): boolean => {
};

const isValidImageUrl = (url: string): boolean => {
if (url.startsWith("data:image/")) {
return true;
}
const imageExtensions = /\.(jpeg|jpg|gif|png|svg|webp)$/i;
const cleanedUrl = url.split("?")[0];
return imageExtensions.test(cleanedUrl);
@@ -50,19 +53,21 @@ const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => {
);
};

const ImageRenderer: React.FC<{ imageUrl: string }> = ({ imageUrl }) => (
<div className="w-full p-2">
<picture>
<img
src={imageUrl}
alt="Image"
className="h-auto max-w-full"
width="100%"
height="auto"
/>
</picture>
</div>
);
const ImageRenderer: React.FC<{ imageUrl: string }> = ({ imageUrl }) => {
return (
<div className="w-full p-2">
<picture>
<img
src={imageUrl}
alt="Image"
className="h-auto max-w-full"
width="100%"
height="auto"
/>
</picture>
</div>
);
};

const AudioRenderer: React.FC<{ audioUrl: string }> = ({ audioUrl }) => (
<div className="w-full p-2">
@@ -92,6 +97,9 @@ export const ContentRenderer: React.FC<{
truncateLongData?: boolean;
}> = ({ value, truncateLongData }) => {
if (typeof value === "string") {
if (value.startsWith("data:image/")) {
return <ImageRenderer imageUrl={value} />;
}
if (isValidVideoUrl(value)) {
return <VideoRenderer videoUrl={value} />;
} else if (isValidImageUrl(value)) {

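A standalone sketch (not part of the diff) mirroring the isValidImageUrl logic above so its behaviour can be checked in isolation; the example URLs are made up.

const isValidImageUrlSketch = (url: string): boolean => {
  // Data URLs are accepted without looking at the file extension.
  if (url.startsWith("data:image/")) {
    return true;
  }
  // Otherwise the query string is dropped and the extension is checked against an allow-list.
  const imageExtensions = /\.(jpeg|jpg|gif|png|svg|webp)$/i;
  return imageExtensions.test(url.split("?")[0]);
};

console.log(isValidImageUrlSketch("data:image/png;base64,iVBORw0KGgo=")); // true
console.log(isValidImageUrlSketch("https://example.com/cat.webp?size=large")); // true
console.log(isValidImageUrlSketch("https://example.com/report.pdf")); // false
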
@@ -29,6 +29,9 @@ import {
StoreReview,
ScheduleCreatable,
Schedule,
APIKeyPermission,
CreateAPIKeyResponse,
APIKey,
} from "./types";
import { createBrowserClient } from "@supabase/ssr";
import getServerSupabase from "../supabase/getServerSupabase";
@@ -221,6 +224,36 @@ export default class BackendAPI {
);
}

// API Key related requests
async createAPIKey(
name: string,
permissions: APIKeyPermission[],
description?: string,
): Promise<CreateAPIKeyResponse> {
return this._request("POST", "/api-keys", {
name,
permissions,
description,
});
}

async listAPIKeys(): Promise<APIKey[]> {
return this._get("/api-keys");
}

async revokeAPIKey(keyId: string): Promise<APIKey> {
return this._request("DELETE", `/api-keys/${keyId}`);
}

async updateAPIKeyPermissions(
keyId: string,
permissions: APIKeyPermission[],
): Promise<APIKey> {
return this._request("PUT", `/api-keys/${keyId}/permissions`, {
permissions,
});
}

/**
* @returns `true` if a ping event was received, `false` if provider doesn't support pinging but the webhook exists.
* @throws `Error` if the webhook does not exist.

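A hypothetical usage sketch (not from the diff) of the API-key client methods above; the client import path and the way a BackendAPI instance is obtained are assumptions, while the method signatures and the APIKeyPermission values come from the diff.

import BackendAPI from "@/lib/autogpt-server-api/client";
import { APIKeyPermission } from "@/lib/autogpt-server-api/types";

async function rotateCIKey(api: BackendAPI): Promise<void> {
  // Create a key that can only execute graphs.
  const { api_key, plain_text_key } = await api.createAPIKey(
    "ci-runner",
    [APIKeyPermission.EXECUTE_GRAPH],
    "Key used by the CI pipeline",
  );
  // The response carries the plain-text key; APIKey objects returned later only expose prefix/postfix.
  console.log("Store this now:", plain_text_key);

  // Revoke any older keys with the same name.
  const keys = await api.listAPIKeys();
  for (const key of keys) {
    if (key.id !== api_key.id && key.name === "ci-runner") {
      await api.revokeAPIKey(key.id);
    }
  }
}
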
@@ -41,7 +41,7 @@ export type BlockIOSubSchema =
| BlockIOSimpleTypeSubSchema
| BlockIOCombinedTypeSubSchema;

type BlockIOSimpleTypeSubSchema =
export type BlockIOSimpleTypeSubSchema =
| BlockIOObjectSubSchema
| BlockIOCredentialsSubSchema
| BlockIOKVSubSchema
@@ -113,6 +113,7 @@ export const PROVIDER_NAMES = {
JINA: "jina",
MEDIUM: "medium",
NOTION: "notion",
NVIDIA: "nvidia",
OLLAMA: "ollama",
OPENAI: "openai",
OPENWEATHERMAP: "openweathermap",
@@ -125,6 +126,7 @@ export const PROVIDER_NAMES = {
UNREAL_SPEECH: "unreal_speech",
EXA: "exa",
HUBSPOT: "hubspot",
TWITTER: "twitter",
} as const;
// --8<-- [end:BlockIOCredentialsSubSchema]

@@ -512,3 +514,36 @@ export type StoreReviewCreate = {
score: number;
comments?: string;
};

// API Key Types

export enum APIKeyPermission {
EXECUTE_GRAPH = "EXECUTE_GRAPH",
READ_GRAPH = "READ_GRAPH",
EXECUTE_BLOCK = "EXECUTE_BLOCK",
READ_BLOCK = "READ_BLOCK",
}

export enum APIKeyStatus {
ACTIVE = "ACTIVE",
REVOKED = "REVOKED",
SUSPENDED = "SUSPENDED",
}

export interface APIKey {
id: string;
name: string;
prefix: string;
postfix: string;
status: APIKeyStatus;
permissions: APIKeyPermission[];
created_at: string;
last_used_at?: string;
revoked_at?: string;
description?: string;
}

export interface CreateAPIKeyResponse {
api_key: APIKey;
plain_text_key: string;
}

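An illustrative helper (not part of the diff) showing how the APIKey shape above could be rendered as a masked label in the UI; the "..." separator and the wording are assumptions, while the fields and enum values come from the diff.

import { APIKey, APIKeyStatus } from "@/lib/autogpt-server-api/types";

export function describeAPIKey(key: APIKey): string {
  // Only the stored prefix and postfix are available; the full key is never returned again.
  const masked = `${key.prefix}...${key.postfix}`;
  const state = key.status === APIKeyStatus.ACTIVE ? "active" : key.status.toLowerCase();
  return `${key.name} (${masked}), ${state}, permissions: ${key.permissions.join(", ")}`;
}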