Compare commits
33 Commits
master
...
ntindle/op
Author | SHA1 | Date |
---|---|---|
![]() |
2fd8c8d261 | |
![]() |
fd6f28fa57 | |
![]() |
4b17cc9963 | |
![]() |
00bb7c67b3 | |
![]() |
6b31356264 | |
![]() |
a88c865437 | |
![]() |
287aa819bb | |
![]() |
db21c6d4bc | |
![]() |
59dd75d016 | |
![]() |
38761f6706 | |
![]() |
513e4eae4b | |
![]() |
fec9d348a0 | |
![]() |
75634e6155 | |
![]() |
6cf77c264a | |
![]() |
2b5c94d508 | |
![]() |
1c6b33d9fb | |
![]() |
d4692f33e2 | |
![]() |
d7a9563d49 | |
![]() |
2ea61f8b65 | |
![]() |
2a5f3d167d | |
![]() |
c0a5a01311 | |
![]() |
0aee309f72 | |
![]() |
4c07f6c633 | |
![]() |
c39f27bcd4 | |
![]() |
35dcc6a2a1 | |
![]() |
bef5637f29 | |
![]() |
e933502cbd | |
![]() |
5720225a75 | |
![]() |
cb3808cb78 | |
![]() |
b6b97f10b8 | |
![]() |
0a905c6d66 | |
![]() |
6b3f5b413f | |
![]() |
8d79a62f61 |
|
@ -82,10 +82,13 @@ GROQ_API_KEY=
|
||||||
OPEN_ROUTER_API_KEY=
|
OPEN_ROUTER_API_KEY=
|
||||||
|
|
||||||
# Reddit
|
# Reddit
|
||||||
|
# Reddit
|
||||||
|
# Go to https://www.reddit.com/prefs/apps and create a new app
|
||||||
|
# Choose "script" for the type
|
||||||
|
# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||||
REDDIT_CLIENT_ID=
|
REDDIT_CLIENT_ID=
|
||||||
REDDIT_CLIENT_SECRET=
|
REDDIT_CLIENT_SECRET=
|
||||||
REDDIT_USERNAME=
|
REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
|
||||||
REDDIT_PASSWORD=
|
|
||||||
|
|
||||||
# Discord
|
# Discord
|
||||||
DISCORD_BOT_TOKEN=
|
DISCORD_BOT_TOKEN=
|
||||||
|
|
|
@ -1,22 +1,53 @@
|
||||||
import smtplib
|
import smtplib
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, SecretStr
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="smtp",
|
||||||
|
username=SecretStr("mock-smtp-username"),
|
||||||
|
password=SecretStr("mock-smtp-password"),
|
||||||
|
title="Mock SMTP credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
SMTPCredentials = UserPasswordCredentials
|
||||||
|
SMTPCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.SMTP],
|
||||||
|
Literal["user_password"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class EmailCredentials(BaseModel):
|
def SMTPCredentialsField() -> SMTPCredentialsInput:
|
||||||
|
return CredentialsField(
|
||||||
|
description="The SMTP integration requires a username and password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SmtpConfig(BaseModel):
|
||||||
smtp_server: str = SchemaField(
|
smtp_server: str = SchemaField(
|
||||||
default="smtp.gmail.com", description="SMTP server address"
|
default="smtp.gmail.com", description="SMTP server address"
|
||||||
)
|
)
|
||||||
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
smtp_port: int = SchemaField(default=25, description="SMTP port number")
|
||||||
smtp_username: BlockSecret = SecretField(key="smtp_username")
|
|
||||||
smtp_password: BlockSecret = SecretField(key="smtp_password")
|
|
||||||
|
|
||||||
model_config = ConfigDict(title="Email Credentials")
|
model_config = ConfigDict(title="SMTP Config")
|
||||||
|
|
||||||
|
|
||||||
class SendEmailBlock(Block):
|
class SendEmailBlock(Block):
|
||||||
|
@ -30,10 +61,11 @@ class SendEmailBlock(Block):
|
||||||
body: str = SchemaField(
|
body: str = SchemaField(
|
||||||
description="Body of the email", placeholder="Enter the email body"
|
description="Body of the email", placeholder="Enter the email body"
|
||||||
)
|
)
|
||||||
creds: EmailCredentials = SchemaField(
|
config: SmtpConfig = SchemaField(
|
||||||
description="SMTP credentials",
|
description="SMTP Config",
|
||||||
default=EmailCredentials(),
|
default=SmtpConfig(),
|
||||||
)
|
)
|
||||||
|
credentials: SMTPCredentialsInput = SMTPCredentialsField()
|
||||||
|
|
||||||
class Output(BlockSchema):
|
class Output(BlockSchema):
|
||||||
status: str = SchemaField(description="Status of the email sending operation")
|
status: str = SchemaField(description="Status of the email sending operation")
|
||||||
|
@ -43,7 +75,6 @@ class SendEmailBlock(Block):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
disabled=True,
|
|
||||||
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
id="4335878a-394e-4e67-adf2-919877ff49ae",
|
||||||
description="This block sends an email using the provided SMTP credentials.",
|
description="This block sends an email using the provided SMTP credentials.",
|
||||||
categories={BlockCategory.OUTPUT},
|
categories={BlockCategory.OUTPUT},
|
||||||
|
@ -53,25 +84,29 @@ class SendEmailBlock(Block):
|
||||||
"to_email": "recipient@example.com",
|
"to_email": "recipient@example.com",
|
||||||
"subject": "Test Email",
|
"subject": "Test Email",
|
||||||
"body": "This is a test email.",
|
"body": "This is a test email.",
|
||||||
"creds": {
|
"config": {
|
||||||
"smtp_server": "smtp.gmail.com",
|
"smtp_server": "smtp.gmail.com",
|
||||||
"smtp_port": 25,
|
"smtp_port": 25,
|
||||||
"smtp_username": "your-email@gmail.com",
|
|
||||||
"smtp_password": "your-gmail-password",
|
|
||||||
},
|
},
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
},
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_output=[("status", "Email sent successfully")],
|
test_output=[("status", "Email sent successfully")],
|
||||||
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def send_email(
|
def send_email(
|
||||||
creds: EmailCredentials, to_email: str, subject: str, body: str
|
config: SmtpConfig,
|
||||||
|
to_email: str,
|
||||||
|
subject: str,
|
||||||
|
body: str,
|
||||||
|
credentials: SMTPCredentials,
|
||||||
) -> str:
|
) -> str:
|
||||||
smtp_server = creds.smtp_server
|
smtp_server = config.smtp_server
|
||||||
smtp_port = creds.smtp_port
|
smtp_port = config.smtp_port
|
||||||
smtp_username = creds.smtp_username.get_secret_value()
|
smtp_username = credentials.username.get_secret_value()
|
||||||
smtp_password = creds.smtp_password.get_secret_value()
|
smtp_password = credentials.password.get_secret_value()
|
||||||
|
|
||||||
msg = MIMEMultipart()
|
msg = MIMEMultipart()
|
||||||
msg["From"] = smtp_username
|
msg["From"] = smtp_username
|
||||||
|
@ -86,10 +121,13 @@ class SendEmailBlock(Block):
|
||||||
|
|
||||||
return "Email sent successfully"
|
return "Email sent successfully"
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
yield "status", self.send_email(
|
yield "status", self.send_email(
|
||||||
input_data.creds,
|
config=input_data.config,
|
||||||
input_data.to_email,
|
to_email=input_data.to_email,
|
||||||
input_data.subject,
|
subject=input_data.subject,
|
||||||
input_data.body,
|
body=input_data.body,
|
||||||
|
credentials=credentials,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,22 +1,48 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Iterator
|
from typing import Iterator, Literal
|
||||||
|
|
||||||
import praw
|
import praw
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, SecretStr
|
||||||
|
|
||||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
from backend.data.model import BlockSecret, SchemaField, SecretField
|
from backend.data.model import (
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
SchemaField,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.mock import MockObject
|
from backend.util.mock import MockObject
|
||||||
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
|
RedditCredentials = UserPasswordCredentials
|
||||||
|
RedditCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.REDDIT],
|
||||||
|
Literal["user_password"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RedditCredentials(BaseModel):
|
def RedditCredentialsField() -> RedditCredentialsInput:
|
||||||
client_id: BlockSecret = SecretField(key="reddit_client_id")
|
"""Creates a Reddit credentials input on a block."""
|
||||||
client_secret: BlockSecret = SecretField(key="reddit_client_secret")
|
return CredentialsField(
|
||||||
username: BlockSecret = SecretField(key="reddit_username")
|
description="The Reddit integration requires a username and password.",
|
||||||
password: BlockSecret = SecretField(key="reddit_password")
|
)
|
||||||
user_agent: str = "AutoGPT:1.0 (by /u/autogpt)"
|
|
||||||
|
|
||||||
model_config = ConfigDict(title="Reddit Credentials")
|
|
||||||
|
TEST_CREDENTIALS = UserPasswordCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="reddit",
|
||||||
|
username=SecretStr("mock-reddit-username"),
|
||||||
|
password=SecretStr("mock-reddit-password"),
|
||||||
|
title="Mock Reddit credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RedditPost(BaseModel):
|
class RedditPost(BaseModel):
|
||||||
|
@ -31,13 +57,16 @@ class RedditComment(BaseModel):
|
||||||
comment: str
|
comment: str
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||||
client = praw.Reddit(
|
client = praw.Reddit(
|
||||||
client_id=creds.client_id.get_secret_value(),
|
client_id=settings.secrets.reddit_client_id,
|
||||||
client_secret=creds.client_secret.get_secret_value(),
|
client_secret=settings.secrets.reddit_client_secret,
|
||||||
username=creds.username.get_secret_value(),
|
username=creds.username.get_secret_value(),
|
||||||
password=creds.password.get_secret_value(),
|
password=creds.password.get_secret_value(),
|
||||||
user_agent=creds.user_agent,
|
user_agent=settings.config.reddit_user_agent,
|
||||||
)
|
)
|
||||||
me = client.user.me()
|
me = client.user.me()
|
||||||
if not me:
|
if not me:
|
||||||
|
@ -48,11 +77,11 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
|
||||||
|
|
||||||
class GetRedditPostsBlock(Block):
|
class GetRedditPostsBlock(Block):
|
||||||
class Input(BlockSchema):
|
class Input(BlockSchema):
|
||||||
subreddit: str = SchemaField(description="Subreddit name")
|
subreddit: str = SchemaField(
|
||||||
creds: RedditCredentials = SchemaField(
|
description="Subreddit name, excluding the /r/ prefix",
|
||||||
description="Reddit credentials",
|
default="writingprompts",
|
||||||
default=RedditCredentials(),
|
|
||||||
)
|
)
|
||||||
|
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||||
last_minutes: int | None = SchemaField(
|
last_minutes: int | None = SchemaField(
|
||||||
description="Post time to stop minutes ago while fetching posts",
|
description="Post time to stop minutes ago while fetching posts",
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -70,20 +99,18 @@ class GetRedditPostsBlock(Block):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
disabled=True,
|
|
||||||
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
|
||||||
description="This block fetches Reddit posts from a defined subreddit name.",
|
description="This block fetches Reddit posts from a defined subreddit name.",
|
||||||
categories={BlockCategory.SOCIAL},
|
categories={BlockCategory.SOCIAL},
|
||||||
|
disabled=(
|
||||||
|
not settings.secrets.reddit_client_id
|
||||||
|
or not settings.secrets.reddit_client_secret
|
||||||
|
),
|
||||||
input_schema=GetRedditPostsBlock.Input,
|
input_schema=GetRedditPostsBlock.Input,
|
||||||
output_schema=GetRedditPostsBlock.Output,
|
output_schema=GetRedditPostsBlock.Output,
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
test_input={
|
test_input={
|
||||||
"creds": {
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
"client_id": "client_id",
|
|
||||||
"client_secret": "client_secret",
|
|
||||||
"username": "username",
|
|
||||||
"password": "password",
|
|
||||||
"user_agent": "user_agent",
|
|
||||||
},
|
|
||||||
"subreddit": "subreddit",
|
"subreddit": "subreddit",
|
||||||
"last_post": "id3",
|
"last_post": "id3",
|
||||||
"post_limit": 2,
|
"post_limit": 2,
|
||||||
|
@ -103,7 +130,7 @@ class GetRedditPostsBlock(Block):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
test_mock={
|
test_mock={
|
||||||
"get_posts": lambda _: [
|
"get_posts": lambda input_data, credentials: [
|
||||||
MockObject(id="id1", title="title1", selftext="body1"),
|
MockObject(id="id1", title="title1", selftext="body1"),
|
||||||
MockObject(id="id2", title="title2", selftext="body2"),
|
MockObject(id="id2", title="title2", selftext="body2"),
|
||||||
MockObject(id="id3", title="title2", selftext="body2"),
|
MockObject(id="id3", title="title2", selftext="body2"),
|
||||||
|
@ -112,14 +139,18 @@ class GetRedditPostsBlock(Block):
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_posts(input_data: Input) -> Iterator[praw.reddit.Submission]:
|
def get_posts(
|
||||||
client = get_praw(input_data.creds)
|
input_data: Input, *, credentials: RedditCredentials
|
||||||
|
) -> Iterator[praw.reddit.Submission]:
|
||||||
|
client = get_praw(credentials)
|
||||||
subreddit = client.subreddit(input_data.subreddit)
|
subreddit = client.subreddit(input_data.subreddit)
|
||||||
return subreddit.new(limit=input_data.post_limit or 10)
|
return subreddit.new(limit=input_data.post_limit or 10)
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
|
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
current_time = datetime.now(tz=timezone.utc)
|
current_time = datetime.now(tz=timezone.utc)
|
||||||
for post in self.get_posts(input_data):
|
for post in self.get_posts(input_data=input_data, credentials=credentials):
|
||||||
if input_data.last_minutes:
|
if input_data.last_minutes:
|
||||||
post_datetime = datetime.fromtimestamp(
|
post_datetime = datetime.fromtimestamp(
|
||||||
post.created_utc, tz=timezone.utc
|
post.created_utc, tz=timezone.utc
|
||||||
|
@ -141,9 +172,7 @@ class GetRedditPostsBlock(Block):
|
||||||
|
|
||||||
class PostRedditCommentBlock(Block):
|
class PostRedditCommentBlock(Block):
|
||||||
class Input(BlockSchema):
|
class Input(BlockSchema):
|
||||||
creds: RedditCredentials = SchemaField(
|
credentials: RedditCredentialsInput = RedditCredentialsField()
|
||||||
description="Reddit credentials", default=RedditCredentials()
|
|
||||||
)
|
|
||||||
data: RedditComment = SchemaField(description="Reddit comment")
|
data: RedditComment = SchemaField(description="Reddit comment")
|
||||||
|
|
||||||
class Output(BlockSchema):
|
class Output(BlockSchema):
|
||||||
|
@ -156,7 +185,15 @@ class PostRedditCommentBlock(Block):
|
||||||
categories={BlockCategory.SOCIAL},
|
categories={BlockCategory.SOCIAL},
|
||||||
input_schema=PostRedditCommentBlock.Input,
|
input_schema=PostRedditCommentBlock.Input,
|
||||||
output_schema=PostRedditCommentBlock.Output,
|
output_schema=PostRedditCommentBlock.Output,
|
||||||
test_input={"data": {"post_id": "id", "comment": "comment"}},
|
disabled=(
|
||||||
|
not settings.secrets.reddit_client_id
|
||||||
|
or not settings.secrets.reddit_client_secret
|
||||||
|
),
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
test_input={
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
"data": {"post_id": "id", "comment": "comment"},
|
||||||
|
},
|
||||||
test_output=[("comment_id", "dummy_comment_id")],
|
test_output=[("comment_id", "dummy_comment_id")],
|
||||||
test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
|
test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
|
||||||
)
|
)
|
||||||
|
@ -170,5 +207,7 @@ class PostRedditCommentBlock(Block):
|
||||||
raise ValueError("Failed to post comment.")
|
raise ValueError("Failed to post comment.")
|
||||||
return new_comment.id
|
return new_comment.id
|
||||||
|
|
||||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
def run(
|
||||||
yield "comment_id", self.reply_post(input_data.creds, input_data.data)
|
self, input_data: Input, *, credentials: RedditCredentials, **kwargs
|
||||||
|
) -> BlockOutput:
|
||||||
|
yield "comment_id", self.reply_post(credentials, input_data.data)
|
||||||
|
|
|
@ -136,6 +136,7 @@ async def create_graph_execution(
|
||||||
graph_version: int,
|
graph_version: int,
|
||||||
nodes_input: list[tuple[str, BlockInput]],
|
nodes_input: list[tuple[str, BlockInput]],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
preset_id: str | None = None,
|
||||||
) -> tuple[str, list[ExecutionResult]]:
|
) -> tuple[str, list[ExecutionResult]]:
|
||||||
"""
|
"""
|
||||||
Create a new AgentGraphExecution record.
|
Create a new AgentGraphExecution record.
|
||||||
|
@ -163,6 +164,7 @@ async def create_graph_execution(
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
|
"agentPresetId": preset_id,
|
||||||
},
|
},
|
||||||
include=GRAPH_EXECUTION_INCLUDE,
|
include=GRAPH_EXECUTION_INCLUDE,
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,7 +6,13 @@ from datetime import datetime, timezone
|
||||||
from typing import Any, Literal, Optional, Type
|
from typing import Any, Literal, Optional, Type
|
||||||
|
|
||||||
import prisma
|
import prisma
|
||||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
from prisma.models import (
|
||||||
|
AgentGraph,
|
||||||
|
AgentGraphExecution,
|
||||||
|
AgentNode,
|
||||||
|
AgentNodeLink,
|
||||||
|
StoreListingVersion,
|
||||||
|
)
|
||||||
from prisma.types import AgentGraphWhereInput
|
from prisma.types import AgentGraphWhereInput
|
||||||
from pydantic.fields import computed_field
|
from pydantic.fields import computed_field
|
||||||
|
|
||||||
|
@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | Non
|
||||||
async def get_graph(
|
async def get_graph(
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
version: int | None = None,
|
version: int | None = None,
|
||||||
template: bool = False,
|
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
for_export: bool = False,
|
for_export: bool = False,
|
||||||
) -> GraphModel | None:
|
) -> GraphModel | None:
|
||||||
|
@ -543,21 +548,36 @@ async def get_graph(
|
||||||
where_clause: AgentGraphWhereInput = {
|
where_clause: AgentGraphWhereInput = {
|
||||||
"id": graph_id,
|
"id": graph_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if version is not None:
|
if version is not None:
|
||||||
where_clause["version"] = version
|
where_clause["version"] = version
|
||||||
elif not template:
|
else:
|
||||||
where_clause["isActive"] = True
|
where_clause["isActive"] = True
|
||||||
|
|
||||||
# TODO: Fix hack workaround to get adding store agents to work
|
|
||||||
if user_id is not None and not template:
|
|
||||||
where_clause["userId"] = user_id
|
|
||||||
|
|
||||||
graph = await AgentGraph.prisma().find_first(
|
graph = await AgentGraph.prisma().find_first(
|
||||||
where=where_clause,
|
where=where_clause,
|
||||||
include=AGENT_GRAPH_INCLUDE,
|
include=AGENT_GRAPH_INCLUDE,
|
||||||
order={"version": "desc"},
|
order={"version": "desc"},
|
||||||
)
|
)
|
||||||
return GraphModel.from_db(graph, for_export) if graph else None
|
|
||||||
|
# The Graph has to be owned by the user or a store listing.
|
||||||
|
if (
|
||||||
|
graph is None
|
||||||
|
or graph.userId != user_id
|
||||||
|
and not (
|
||||||
|
await StoreListingVersion.prisma().find_first(
|
||||||
|
where=prisma.types.StoreListingVersionWhereInput(
|
||||||
|
agentId=graph_id,
|
||||||
|
agentVersion=version or graph.version,
|
||||||
|
isDeleted=False,
|
||||||
|
StoreListing={"is": {"isApproved": True}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return GraphModel.from_db(graph, for_export)
|
||||||
|
|
||||||
|
|
||||||
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
|
||||||
|
@ -611,9 +631,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
|
||||||
async with transaction() as tx:
|
async with transaction() as tx:
|
||||||
await __create_graph(tx, graph, user_id)
|
await __create_graph(tx, graph, user_id)
|
||||||
|
|
||||||
if created_graph := await get_graph(
|
if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
|
||||||
graph.id, graph.version, graph.is_template, user_id=user_id
|
|
||||||
):
|
|
||||||
return created_graph
|
return created_graph
|
||||||
|
|
||||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
@ -206,20 +207,35 @@ class OAuth2Credentials(_BaseCredentials):
|
||||||
class APIKeyCredentials(_BaseCredentials):
|
class APIKeyCredentials(_BaseCredentials):
|
||||||
type: Literal["api_key"] = "api_key"
|
type: Literal["api_key"] = "api_key"
|
||||||
api_key: SecretStr
|
api_key: SecretStr
|
||||||
expires_at: Optional[int]
|
expires_at: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Unix timestamp (seconds) indicating when the API key expires (if at all)",
|
||||||
|
)
|
||||||
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
||||||
|
|
||||||
def bearer(self) -> str:
|
def bearer(self) -> str:
|
||||||
return f"Bearer {self.api_key.get_secret_value()}"
|
return f"Bearer {self.api_key.get_secret_value()}"
|
||||||
|
|
||||||
|
|
||||||
|
class UserPasswordCredentials(_BaseCredentials):
|
||||||
|
type: Literal["user_password"] = "user_password"
|
||||||
|
username: SecretStr
|
||||||
|
password: SecretStr
|
||||||
|
|
||||||
|
def bearer(self) -> str:
|
||||||
|
# Converting the string to bytes using encode()
|
||||||
|
# Base64 encoding it with base64.b64encode()
|
||||||
|
# Converting the resulting bytes back to a string with decode()
|
||||||
|
return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
|
||||||
|
|
||||||
|
|
||||||
Credentials = Annotated[
|
Credentials = Annotated[
|
||||||
OAuth2Credentials | APIKeyCredentials,
|
OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
CredentialsType = Literal["api_key", "oauth2"]
|
CredentialsType = Literal["api_key", "oauth2", "user_password"]
|
||||||
|
|
||||||
|
|
||||||
class OAuthState(BaseModel):
|
class OAuthState(BaseModel):
|
||||||
|
|
|
@ -780,7 +780,8 @@ class ExecutionManager(AppService):
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
data: BlockInput,
|
data: BlockInput,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
graph_version: int | None = None,
|
graph_version: int,
|
||||||
|
preset_id: str | None = None,
|
||||||
) -> GraphExecutionEntry:
|
) -> GraphExecutionEntry:
|
||||||
graph: GraphModel | None = self.db_client.get_graph(
|
graph: GraphModel | None = self.db_client.get_graph(
|
||||||
graph_id=graph_id, user_id=user_id, version=graph_version
|
graph_id=graph_id, user_id=user_id, version=graph_version
|
||||||
|
@ -829,6 +830,7 @@ class ExecutionManager(AppService):
|
||||||
graph_version=graph.version,
|
graph_version=graph.version,
|
||||||
nodes_input=nodes_input,
|
nodes_input=nodes_input,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
preset_id=preset_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
starting_node_execs = []
|
starting_node_execs = []
|
||||||
|
|
|
@ -63,7 +63,10 @@ def execute_graph(**kwargs):
|
||||||
try:
|
try:
|
||||||
log(f"Executing recurring job for graph #{args.graph_id}")
|
log(f"Executing recurring job for graph #{args.graph_id}")
|
||||||
get_execution_client().add_execution(
|
get_execution_client().add_execution(
|
||||||
args.graph_id, args.input_data, args.user_id
|
graph_id=args.graph_id,
|
||||||
|
data=args.input_data,
|
||||||
|
user_id=args.user_id,
|
||||||
|
graph_version=args.graph_version,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
logger.exception(f"Error executing graph {args.graph_id}: {e}")
|
||||||
|
|
|
@ -25,9 +25,11 @@ class ProviderName(str, Enum):
|
||||||
OPENWEATHERMAP = "openweathermap"
|
OPENWEATHERMAP = "openweathermap"
|
||||||
OPEN_ROUTER = "open_router"
|
OPEN_ROUTER = "open_router"
|
||||||
PINECONE = "pinecone"
|
PINECONE = "pinecone"
|
||||||
|
REDDIT = "reddit"
|
||||||
REPLICATE = "replicate"
|
REPLICATE = "replicate"
|
||||||
REVID = "revid"
|
REVID = "revid"
|
||||||
SLANT3D = "slant3d"
|
SLANT3D = "slant3d"
|
||||||
|
SMTP = "smtp"
|
||||||
TWITTER = "twitter"
|
TWITTER = "twitter"
|
||||||
UNREAL_SPEECH = "unreal_speech"
|
UNREAL_SPEECH = "unreal_speech"
|
||||||
# --8<-- [end:ProviderName]
|
# --8<-- [end:ProviderName]
|
||||||
|
|
|
@ -168,7 +168,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
|
||||||
|
|
||||||
id = str(uuid4())
|
id = str(uuid4())
|
||||||
secret = secrets.token_hex(32)
|
secret = secrets.token_hex(32)
|
||||||
provider_name = self.PROVIDER_NAME
|
provider_name: ProviderName = self.PROVIDER_NAME
|
||||||
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
|
ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
|
||||||
if register:
|
if register:
|
||||||
if not credentials:
|
if not credentials:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from backend.data import integrations
|
from backend.data import integrations
|
||||||
from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials
|
from backend.data.model import Credentials
|
||||||
|
|
||||||
from ._base import WT, BaseWebhooksManager
|
from ._base import WT, BaseWebhooksManager
|
||||||
|
|
||||||
|
@ -25,6 +25,6 @@ class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
|
||||||
async def _deregister_webhook(
|
async def _deregister_webhook(
|
||||||
self,
|
self,
|
||||||
webhook: integrations.Webhook,
|
webhook: integrations.Webhook,
|
||||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
credentials: Credentials,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
from typing import TYPE_CHECKING, Annotated, Literal
|
from typing import TYPE_CHECKING, Annotated, Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from backend.data.graph import set_node_webhook
|
from backend.data.graph import set_node_webhook
|
||||||
from backend.data.integrations import (
|
from backend.data.integrations import (
|
||||||
|
@ -12,12 +12,7 @@ from backend.data.integrations import (
|
||||||
publish_webhook_event,
|
publish_webhook_event,
|
||||||
wait_for_webhook_event,
|
wait_for_webhook_event,
|
||||||
)
|
)
|
||||||
from backend.data.model import (
|
from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
|
||||||
APIKeyCredentials,
|
|
||||||
Credentials,
|
|
||||||
CredentialsType,
|
|
||||||
OAuth2Credentials,
|
|
||||||
)
|
|
||||||
from backend.executor.manager import ExecutionManager
|
from backend.executor.manager import ExecutionManager
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.oauth import HANDLERS_BY_NAME
|
from backend.integrations.oauth import HANDLERS_BY_NAME
|
||||||
|
@ -199,22 +194,15 @@ def get_credential(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/credentials", status_code=201)
|
@router.post("/{provider}/credentials", status_code=201)
|
||||||
def create_api_key_credentials(
|
def create_credentials(
|
||||||
user_id: Annotated[str, Depends(get_user_id)],
|
user_id: Annotated[str, Depends(get_user_id)],
|
||||||
provider: Annotated[
|
provider: Annotated[
|
||||||
ProviderName, Path(title="The provider to create credentials for")
|
ProviderName, Path(title="The provider to create credentials for")
|
||||||
],
|
],
|
||||||
api_key: Annotated[str, Body(title="The API key to store")],
|
credential: Credentials,
|
||||||
title: Annotated[str, Body(title="Optional title for the credentials")],
|
) -> Credentials:
|
||||||
expires_at: Annotated[
|
new_credentials = credential.__class__(
|
||||||
int | None, Body(title="Unix timestamp when the key expires")
|
provider=provider, **credential.model_dump(exclude={"provider"})
|
||||||
] = None,
|
|
||||||
) -> APIKeyCredentials:
|
|
||||||
new_credentials = APIKeyCredentials(
|
|
||||||
provider=provider,
|
|
||||||
api_key=SecretStr(api_key),
|
|
||||||
title=title,
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -320,7 +308,8 @@ async def webhook_ingress_generic(
|
||||||
continue
|
continue
|
||||||
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
|
||||||
executor.add_execution(
|
executor.add_execution(
|
||||||
node.graph_id,
|
graph_id=node.graph_id,
|
||||||
|
graph_version=node.graph_version,
|
||||||
data={f"webhook_{webhook_id}_payload": payload},
|
data={f"webhook_{webhook_id}_payload": payload},
|
||||||
user_id=webhook.user_id,
|
user_id=webhook.user_id,
|
||||||
)
|
)
|
||||||
|
|
|
@ -56,3 +56,18 @@ class SetGraphActiveVersion(pydantic.BaseModel):
|
||||||
|
|
||||||
class UpdatePermissionsRequest(pydantic.BaseModel):
|
class UpdatePermissionsRequest(pydantic.BaseModel):
|
||||||
permissions: List[APIKeyPermission]
|
permissions: List[APIKeyPermission]
|
||||||
|
|
||||||
|
|
||||||
|
class Pagination(pydantic.BaseModel):
|
||||||
|
total_items: int = pydantic.Field(
|
||||||
|
description="Total number of items.", examples=[42]
|
||||||
|
)
|
||||||
|
total_pages: int = pydantic.Field(
|
||||||
|
description="Total number of pages.", examples=[2]
|
||||||
|
)
|
||||||
|
current_page: int = pydantic.Field(
|
||||||
|
description="Current_page page number.", examples=[1]
|
||||||
|
)
|
||||||
|
page_size: int = pydantic.Field(
|
||||||
|
description="Number of items per page.", examples=[25]
|
||||||
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import contextlib
|
||||||
import logging
|
import logging
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
import autogpt_libs.auth.models
|
||||||
import fastapi
|
import fastapi
|
||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import starlette.middleware.cors
|
import starlette.middleware.cors
|
||||||
|
@ -16,7 +17,9 @@ import backend.data.db
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
import backend.data.user
|
import backend.data.user
|
||||||
import backend.server.routers.v1
|
import backend.server.routers.v1
|
||||||
|
import backend.server.v2.library.model
|
||||||
import backend.server.v2.library.routes
|
import backend.server.v2.library.routes
|
||||||
|
import backend.server.v2.store.model
|
||||||
import backend.server.v2.store.routes
|
import backend.server.v2.store.routes
|
||||||
import backend.util.service
|
import backend.util.service
|
||||||
import backend.util.settings
|
import backend.util.settings
|
||||||
|
@ -117,9 +120,24 @@ class AgentServer(backend.util.service.AppProcess):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def test_execute_graph(
|
async def test_execute_graph(
|
||||||
graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
node_input: dict[typing.Any, typing.Any],
|
||||||
|
user_id: str,
|
||||||
):
|
):
|
||||||
return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
|
return backend.server.routers.v1.execute_graph(
|
||||||
|
graph_id, graph_version, node_input, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_get_graph(
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
|
return await backend.server.routers.v1.get_graph(
|
||||||
|
graph_id, user_id, graph_version
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def test_create_graph(
|
async def test_create_graph(
|
||||||
|
@ -149,5 +167,71 @@ class AgentServer(backend.util.service.AppProcess):
|
||||||
async def test_delete_graph(graph_id: str, user_id: str):
|
async def test_delete_graph(graph_id: str, user_id: str):
|
||||||
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
return await backend.server.routers.v1.delete_graph(graph_id, user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
|
||||||
|
return await backend.server.v2.library.routes.presets.get_presets(
|
||||||
|
user_id=user_id, page=page, page_size=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_get_preset(preset_id: str, user_id: str):
|
||||||
|
return await backend.server.v2.library.routes.presets.get_preset(
|
||||||
|
preset_id=preset_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_create_preset(
|
||||||
|
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
|
return await backend.server.v2.library.routes.presets.create_preset(
|
||||||
|
preset=preset, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_update_preset(
|
||||||
|
preset_id: str,
|
||||||
|
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
|
return await backend.server.v2.library.routes.presets.update_preset(
|
||||||
|
preset_id=preset_id, preset=preset, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_delete_preset(preset_id: str, user_id: str):
|
||||||
|
return await backend.server.v2.library.routes.presets.delete_preset(
|
||||||
|
preset_id=preset_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_execute_preset(
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
preset_id: str,
|
||||||
|
node_input: dict[typing.Any, typing.Any],
|
||||||
|
user_id: str,
|
||||||
|
):
|
||||||
|
return await backend.server.v2.library.routes.presets.execute_preset(
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
preset_id=preset_id,
|
||||||
|
node_input=node_input,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_create_store_listing(
|
||||||
|
request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
|
||||||
|
):
|
||||||
|
return await backend.server.v2.store.routes.create_submission(request, user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def test_review_store_listing(
|
||||||
|
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||||
|
user: autogpt_libs.auth.models.User,
|
||||||
|
):
|
||||||
|
return await backend.server.v2.store.routes.review_submission(request, user)
|
||||||
|
|
||||||
def set_test_dependency_overrides(self, overrides: dict):
|
def set_test_dependency_overrides(self, overrides: dict):
|
||||||
app.dependency_overrides.update(overrides)
|
app.dependency_overrides.update(overrides)
|
||||||
|
|
|
@ -13,6 +13,7 @@ from typing_extensions import Optional, TypedDict
|
||||||
import backend.data.block
|
import backend.data.block
|
||||||
import backend.server.integrations.router
|
import backend.server.integrations.router
|
||||||
import backend.server.routers.analytics
|
import backend.server.routers.analytics
|
||||||
|
import backend.server.v2.library.db
|
||||||
from backend.data import execution as execution_db
|
from backend.data import execution as execution_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.api_key import (
|
from backend.data.api_key import (
|
||||||
|
@ -180,11 +181,6 @@ async def get_graph(
|
||||||
tags=["graphs"],
|
tags=["graphs"],
|
||||||
dependencies=[Depends(auth_middleware)],
|
dependencies=[Depends(auth_middleware)],
|
||||||
)
|
)
|
||||||
@v1_router.get(
|
|
||||||
path="/templates/{graph_id}/versions",
|
|
||||||
tags=["templates", "graphs"],
|
|
||||||
dependencies=[Depends(auth_middleware)],
|
|
||||||
)
|
|
||||||
async def get_graph_all_versions(
|
async def get_graph_all_versions(
|
||||||
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||||
) -> Sequence[graph_db.GraphModel]:
|
) -> Sequence[graph_db.GraphModel]:
|
||||||
|
@ -200,12 +196,11 @@ async def get_graph_all_versions(
|
||||||
async def create_new_graph(
|
async def create_new_graph(
|
||||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
return await do_create_graph(create_graph, is_template=False, user_id=user_id)
|
return await do_create_graph(create_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
async def do_create_graph(
|
async def do_create_graph(
|
||||||
create_graph: CreateGraph,
|
create_graph: CreateGraph,
|
||||||
is_template: bool,
|
|
||||||
# user_id doesn't have to be annotated like on other endpoints,
|
# user_id doesn't have to be annotated like on other endpoints,
|
||||||
# because create_graph isn't used directly as an endpoint
|
# because create_graph isn't used directly as an endpoint
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -217,7 +212,6 @@ async def do_create_graph(
|
||||||
graph = await graph_db.get_graph(
|
graph = await graph_db.get_graph(
|
||||||
create_graph.template_id,
|
create_graph.template_id,
|
||||||
create_graph.template_version,
|
create_graph.template_version,
|
||||||
template=True,
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
if not graph:
|
if not graph:
|
||||||
|
@ -225,13 +219,18 @@ async def do_create_graph(
|
||||||
400, detail=f"Template #{create_graph.template_id} not found"
|
400, detail=f"Template #{create_graph.template_id} not found"
|
||||||
)
|
)
|
||||||
graph.version = 1
|
graph.version = 1
|
||||||
|
|
||||||
|
# Create a library agent for the new graph
|
||||||
|
await backend.server.v2.library.db.create_library_agent(
|
||||||
|
graph.id,
|
||||||
|
graph.version,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Either graph or template_id must be provided."
|
status_code=400, detail="Either graph or template_id must be provided."
|
||||||
)
|
)
|
||||||
|
|
||||||
graph.is_template = is_template
|
|
||||||
graph.is_active = not is_template
|
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
||||||
|
|
||||||
graph = await graph_db.create_graph(graph, user_id=user_id)
|
graph = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
@ -261,11 +260,6 @@ async def delete_graph(
|
||||||
@v1_router.put(
|
@v1_router.put(
|
||||||
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
|
||||||
)
|
)
|
||||||
@v1_router.put(
|
|
||||||
path="/templates/{graph_id}",
|
|
||||||
tags=["templates", "graphs"],
|
|
||||||
dependencies=[Depends(auth_middleware)],
|
|
||||||
)
|
|
||||||
async def update_graph(
|
async def update_graph(
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
|
@ -298,6 +292,11 @@ async def update_graph(
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
|
|
||||||
|
# Keep the library agent up to date with the new active version
|
||||||
|
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||||
|
user_id, graph.id, graph.version
|
||||||
|
)
|
||||||
|
|
||||||
def get_credentials(credentials_id: str) -> "Credentials | None":
|
def get_credentials(credentials_id: str) -> "Credentials | None":
|
||||||
return integration_creds_manager.get(user_id, credentials_id)
|
return integration_creds_manager.get(user_id, credentials_id)
|
||||||
|
|
||||||
|
@ -353,6 +352,12 @@ async def set_graph_active_version(
|
||||||
version=new_active_version,
|
version=new_active_version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Keep the library agent up to date with the new active version
|
||||||
|
await backend.server.v2.library.db.update_agent_version_in_library(
|
||||||
|
user_id, new_active_graph.id, new_active_graph.version
|
||||||
|
)
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(
|
await on_graph_deactivate(
|
||||||
|
@ -368,12 +373,13 @@ async def set_graph_active_version(
|
||||||
)
|
)
|
||||||
def execute_graph(
|
def execute_graph(
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
node_input: dict[Any, Any],
|
node_input: dict[Any, Any],
|
||||||
user_id: Annotated[str, Depends(get_user_id)],
|
user_id: Annotated[str, Depends(get_user_id)],
|
||||||
) -> dict[str, Any]: # FIXME: add proper return type
|
) -> dict[str, Any]: # FIXME: add proper return type
|
||||||
try:
|
try:
|
||||||
graph_exec = execution_manager_client().add_execution(
|
graph_exec = execution_manager_client().add_execution(
|
||||||
graph_id, node_input, user_id=user_id
|
graph_id, node_input, user_id=user_id, graph_version=graph_version
|
||||||
)
|
)
|
||||||
return {"id": graph_exec.graph_exec_id}
|
return {"id": graph_exec.graph_exec_id}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -428,47 +434,6 @@ async def get_graph_run_node_execution_results(
|
||||||
return await execution_db.get_execution_results(graph_exec_id)
|
return await execution_db.get_execution_results(graph_exec_id)
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
|
||||||
##################### Templates ########################
|
|
||||||
########################################################
|
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
|
||||||
path="/templates",
|
|
||||||
tags=["graphs", "templates"],
|
|
||||||
dependencies=[Depends(auth_middleware)],
|
|
||||||
)
|
|
||||||
async def get_templates(
|
|
||||||
user_id: Annotated[str, Depends(get_user_id)]
|
|
||||||
) -> Sequence[graph_db.GraphModel]:
|
|
||||||
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
|
|
||||||
|
|
||||||
|
|
||||||
@v1_router.get(
|
|
||||||
path="/templates/{graph_id}",
|
|
||||||
tags=["templates", "graphs"],
|
|
||||||
dependencies=[Depends(auth_middleware)],
|
|
||||||
)
|
|
||||||
async def get_template(
|
|
||||||
graph_id: str, version: int | None = None
|
|
||||||
) -> graph_db.GraphModel:
|
|
||||||
graph = await graph_db.get_graph(graph_id, version, template=True)
|
|
||||||
if not graph:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
@v1_router.post(
|
|
||||||
path="/templates",
|
|
||||||
tags=["templates", "graphs"],
|
|
||||||
dependencies=[Depends(auth_middleware)],
|
|
||||||
)
|
|
||||||
async def create_new_template(
|
|
||||||
create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
|
|
||||||
) -> graph_db.GraphModel:
|
|
||||||
return await do_create_graph(create_graph, is_template=True, user_id=user_id)
|
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
##################### Schedules ########################
|
##################### Schedules ########################
|
||||||
########################################################
|
########################################################
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import prisma.errors
|
import prisma.errors
|
||||||
import prisma.models
|
import prisma.models
|
||||||
|
@ -7,6 +7,7 @@ import prisma.types
|
||||||
|
|
||||||
import backend.data.graph
|
import backend.data.graph
|
||||||
import backend.data.includes
|
import backend.data.includes
|
||||||
|
import backend.server.model
|
||||||
import backend.server.v2.library.model
|
import backend.server.v2.library.model
|
||||||
import backend.server.v2.store.exceptions
|
import backend.server.v2.store.exceptions
|
||||||
|
|
||||||
|
@ -14,90 +15,152 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def get_library_agents(
|
async def get_library_agents(
|
||||||
user_id: str,
|
user_id: str, search_query: str | None = None
|
||||||
) -> List[backend.server.v2.library.model.LibraryAgent]:
|
) -> list[backend.server.v2.library.model.LibraryAgent]:
|
||||||
"""
|
logger.debug(
|
||||||
Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
|
f"Fetching library agents for user_id={user_id} search_query={search_query}"
|
||||||
"""
|
)
|
||||||
logger.debug(f"Getting library agents for user {user_id}")
|
|
||||||
|
|
||||||
try:
|
if search_query and len(search_query.strip()) > 100:
|
||||||
# Get agents created by user with nodes and links
|
logger.warning(f"Search query too long: {search_query}")
|
||||||
user_created = await prisma.models.AgentGraph.prisma().find_many(
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
|
"Search query is too long."
|
||||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get agents in user's library with nodes and links
|
where_clause = prisma.types.LibraryAgentWhereInput(
|
||||||
library_agents = await prisma.models.UserAgent.prisma().find_many(
|
userId=user_id,
|
||||||
where=prisma.types.UserAgentWhereInput(
|
isDeleted=False,
|
||||||
userId=user_id, isDeleted=False, isArchived=False
|
isArchived=False,
|
||||||
),
|
)
|
||||||
include={
|
|
||||||
|
if search_query:
|
||||||
|
where_clause["OR"] = [
|
||||||
|
{
|
||||||
"Agent": {
|
"Agent": {
|
||||||
"include": {
|
"is": {"name": {"contains": search_query, "mode": "insensitive"}}
|
||||||
"AgentNodes": {
|
}
|
||||||
"include": {
|
},
|
||||||
"Input": True,
|
{
|
||||||
"Output": True,
|
"Agent": {
|
||||||
"Webhook": True,
|
"is": {
|
||||||
"AgentBlock": True,
|
"description": {"contains": search_query, "mode": "insensitive"}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||||
|
where=where_clause,
|
||||||
|
include={
|
||||||
|
"Agent": {
|
||||||
|
"include": {
|
||||||
|
"AgentNodes": {"include": {"Input": True, "Output": True}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
order=[{"updatedAt": "desc"}],
|
||||||
|
)
|
||||||
|
logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
|
||||||
|
return [
|
||||||
|
backend.server.v2.library.model.LibraryAgent.from_db(agent)
|
||||||
|
for agent in library_agents
|
||||||
|
]
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error fetching library agents: {e}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Unable to fetch library agents."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to Graph models first
|
|
||||||
graphs = []
|
|
||||||
|
|
||||||
# Add user created agents
|
async def create_library_agent(
|
||||||
for agent in user_created:
|
agent_id: str, agent_version: int, user_id: str
|
||||||
try:
|
) -> prisma.models.LibraryAgent:
|
||||||
graphs.append(backend.data.graph.GraphModel.from_db(agent))
|
"""
|
||||||
except Exception as e:
|
Adds an agent to the user's library (LibraryAgent table)
|
||||||
logger.error(f"Error processing user created agent {agent.id}: {e}")
|
"""
|
||||||
continue
|
|
||||||
|
|
||||||
# Add library agents
|
try:
|
||||||
for agent in library_agents:
|
|
||||||
if agent.Agent:
|
|
||||||
try:
|
|
||||||
graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing library agent {agent.agentId}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Convert Graph models to LibraryAgent models
|
library_agent = await prisma.models.LibraryAgent.prisma().create(
|
||||||
result = []
|
data=prisma.types.LibraryAgentCreateInput(
|
||||||
for graph in graphs:
|
userId=user_id,
|
||||||
result.append(
|
agentId=agent_id,
|
||||||
backend.server.v2.library.model.LibraryAgent(
|
agentVersion=agent_version,
|
||||||
id=graph.id,
|
isCreatedByUser=False,
|
||||||
version=graph.version,
|
useGraphIsActiveVersion=True,
|
||||||
is_active=graph.is_active,
|
|
||||||
name=graph.name,
|
|
||||||
description=graph.description,
|
|
||||||
isCreatedByUser=any(a.id == graph.id for a in user_created),
|
|
||||||
input_schema=graph.input_schema,
|
|
||||||
output_schema=graph.output_schema,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
logger.debug(f"Found {len(result)} library agents")
|
return library_agent
|
||||||
return result
|
|
||||||
|
|
||||||
except prisma.errors.PrismaError as e:
|
except prisma.errors.PrismaError as e:
|
||||||
logger.error(f"Database error getting library agents: {str(e)}")
|
logger.error(f"Database error creating agent to library: {str(e)}")
|
||||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
"Failed to fetch library agents"
|
"Failed to create agent to library"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
|
async def update_agent_version_in_library(
|
||||||
|
user_id: str, agent_id: str, agent_version: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
|
Updates the agent version in the library
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await prisma.models.LibraryAgent.prisma().update(
|
||||||
|
where={
|
||||||
|
"userId": user_id,
|
||||||
|
"agentId": agent_id,
|
||||||
|
"useGraphIsActiveVersion": True,
|
||||||
|
},
|
||||||
|
data=prisma.types.LibraryAgentUpdateInput(
|
||||||
|
Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
|
||||||
|
connect=prisma.types.AgentGraphWhereUniqueInput(
|
||||||
|
id=agent_id,
|
||||||
|
version=agent_version,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error updating agent version in library: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to update agent version in library"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def update_library_agent(
|
||||||
|
library_agent_id: str,
|
||||||
|
user_id: str,
|
||||||
|
auto_update_version: bool = False,
|
||||||
|
is_favorite: bool = False,
|
||||||
|
is_archived: bool = False,
|
||||||
|
is_deleted: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Updates the library agent with the given fields
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await prisma.models.LibraryAgent.prisma().update(
|
||||||
|
where={"id": library_agent_id, "userId": user_id},
|
||||||
|
data=prisma.types.LibraryAgentUpdateInput(
|
||||||
|
useGraphIsActiveVersion=auto_update_version,
|
||||||
|
isFavorite=is_favorite,
|
||||||
|
isArchived=is_archived,
|
||||||
|
isDeleted=is_deleted,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error updating library agent: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to update library agent"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def add_store_agent_to_library(
|
||||||
|
store_listing_version_id: str, user_id: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
|
||||||
if they don't already have it
|
if they don't already have it
|
||||||
"""
|
"""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -131,7 +194,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if user already has this agent
|
# Check if user already has this agent
|
||||||
existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
|
existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
|
||||||
where={
|
where={
|
||||||
"userId": user_id,
|
"userId": user_id,
|
||||||
"agentId": agent.id,
|
"agentId": agent.id,
|
||||||
|
@ -145,9 +208,9 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create UserAgent entry
|
# Create LibraryAgent entry
|
||||||
await prisma.models.UserAgent.prisma().create(
|
await prisma.models.LibraryAgent.prisma().create(
|
||||||
data=prisma.types.UserAgentCreateInput(
|
data=prisma.types.LibraryAgentCreateInput(
|
||||||
userId=user_id,
|
userId=user_id,
|
||||||
agentId=agent.id,
|
agentId=agent.id,
|
||||||
agentVersion=agent.version,
|
agentVersion=agent.version,
|
||||||
|
@ -163,3 +226,116 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
|
||||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
"Failed to add agent to library"
|
"Failed to add agent to library"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
##############################################
|
||||||
|
########### Presets DB Functions #############
|
||||||
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
|
async def get_presets(
|
||||||
|
user_id: str, page: int, page_size: int
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
|
||||||
|
|
||||||
|
try:
|
||||||
|
presets = await prisma.models.AgentPreset.prisma().find_many(
|
||||||
|
where={"userId": user_id},
|
||||||
|
skip=page * page_size,
|
||||||
|
take=page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_items = await prisma.models.AgentPreset.prisma().count(
|
||||||
|
where={"userId": user_id},
|
||||||
|
)
|
||||||
|
total_pages = (total_items + page_size - 1) // page_size
|
||||||
|
|
||||||
|
presets = [
|
||||||
|
backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
|
||||||
|
for preset in presets
|
||||||
|
]
|
||||||
|
|
||||||
|
return backend.server.v2.library.model.LibraryAgentPresetResponse(
|
||||||
|
presets=presets,
|
||||||
|
pagination=backend.server.model.Pagination(
|
||||||
|
total_items=total_items,
|
||||||
|
total_pages=total_pages,
|
||||||
|
current_page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error getting presets: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to fetch presets"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def get_preset(
|
||||||
|
user_id: str, preset_id: str
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
|
||||||
|
try:
|
||||||
|
preset = await prisma.models.AgentPreset.prisma().find_unique(
|
||||||
|
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
|
||||||
|
)
|
||||||
|
if not preset:
|
||||||
|
return None
|
||||||
|
return backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error getting preset: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to fetch preset"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def create_or_update_preset(
|
||||||
|
user_id: str,
|
||||||
|
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||||
|
preset_id: str | None = None,
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||||
|
try:
|
||||||
|
new_preset = await prisma.models.AgentPreset.prisma().upsert(
|
||||||
|
where={
|
||||||
|
"id": preset_id if preset_id else "",
|
||||||
|
},
|
||||||
|
data={
|
||||||
|
"create": {
|
||||||
|
"userId": user_id,
|
||||||
|
"name": preset.name,
|
||||||
|
"description": preset.description,
|
||||||
|
"agentId": preset.agent_id,
|
||||||
|
"agentVersion": preset.agent_version,
|
||||||
|
"isActive": preset.is_active,
|
||||||
|
"InputPresets": {
|
||||||
|
"create": [
|
||||||
|
{"name": name, "data": json.dumps(data)}
|
||||||
|
for name, data in preset.inputs.items()
|
||||||
|
]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"update": {
|
||||||
|
"name": preset.name,
|
||||||
|
"description": preset.description,
|
||||||
|
"isActive": preset.is_active,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error creating preset: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to create preset"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_preset(user_id: str, preset_id: str) -> None:
|
||||||
|
try:
|
||||||
|
await prisma.models.AgentPreset.prisma().update(
|
||||||
|
where={"id": preset_id, "userId": user_id},
|
||||||
|
data={"isDeleted": True},
|
||||||
|
)
|
||||||
|
except prisma.errors.PrismaError as e:
|
||||||
|
logger.error(f"Database error deleting preset: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to delete preset"
|
||||||
|
) from e
|
||||||
|
|
|
@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_library_agents = [
|
mock_library_agents = [
|
||||||
prisma.models.UserAgent(
|
prisma.models.LibraryAgent(
|
||||||
id="ua1",
|
id="ua1",
|
||||||
userId="test-user",
|
userId="test-user",
|
||||||
agentId="agent2",
|
agentId="agent2",
|
||||||
|
@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
|
||||||
createdAt=datetime.now(),
|
createdAt=datetime.now(),
|
||||||
updatedAt=datetime.now(),
|
updatedAt=datetime.now(),
|
||||||
isFavorite=False,
|
isFavorite=False,
|
||||||
|
useGraphIsActiveVersion=True,
|
||||||
Agent=prisma.models.AgentGraph(
|
Agent=prisma.models.AgentGraph(
|
||||||
id="agent2",
|
id="agent2",
|
||||||
version=1,
|
version=1,
|
||||||
|
@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
|
||||||
return_value=mock_user_created
|
return_value=mock_user_created
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||||
mock_user_agent.return_value.find_many = mocker.AsyncMock(
|
mock_library_agent.return_value.find_many = mocker.AsyncMock(
|
||||||
return_value=mock_library_agents
|
return_value=mock_library_agents
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,40 +77,16 @@ async def test_get_library_agents(mocker):
|
||||||
result = await db.get_library_agents("test-user")
|
result = await db.get_library_agents("test-user")
|
||||||
|
|
||||||
# Verify results
|
# Verify results
|
||||||
assert len(result) == 2
|
assert len(result) == 1
|
||||||
assert result[0].id == "agent1"
|
assert result[0].id == "ua1"
|
||||||
assert result[0].name == "Test Agent 1"
|
assert result[0].name == "Test Agent 2"
|
||||||
assert result[0].description == "Test Description 1"
|
assert result[0].description == "Test Description 2"
|
||||||
assert result[0].isCreatedByUser is True
|
assert result[0].is_created_by_user is False
|
||||||
assert result[1].id == "agent2"
|
assert result[0].is_latest_version is True
|
||||||
assert result[1].name == "Test Agent 2"
|
assert result[0].is_favorite is False
|
||||||
assert result[1].description == "Test Description 2"
|
assert result[0].agent_id == "agent2"
|
||||||
assert result[1].isCreatedByUser is False
|
assert result[0].agent_version == 1
|
||||||
|
assert result[0].preset_id is None
|
||||||
# Verify mocks called correctly
|
|
||||||
mock_agent_graph.return_value.find_many.assert_called_once_with(
|
|
||||||
where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
|
|
||||||
include=backend.data.includes.AGENT_GRAPH_INCLUDE,
|
|
||||||
)
|
|
||||||
mock_user_agent.return_value.find_many.assert_called_once_with(
|
|
||||||
where=prisma.types.UserAgentWhereInput(
|
|
||||||
userId="test-user", isDeleted=False, isArchived=False
|
|
||||||
),
|
|
||||||
include={
|
|
||||||
"Agent": {
|
|
||||||
"include": {
|
|
||||||
"AgentNodes": {
|
|
||||||
"include": {
|
|
||||||
"Input": True,
|
|
||||||
"Output": True,
|
|
||||||
"Webhook": True,
|
|
||||||
"AgentBlock": True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -152,26 +129,26 @@ async def test_add_agent_to_library(mocker):
|
||||||
return_value=mock_store_listing
|
return_value=mock_store_listing
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
|
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
|
||||||
mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
|
||||||
mock_user_agent.return_value.create = mocker.AsyncMock()
|
mock_library_agent.return_value.create = mocker.AsyncMock()
|
||||||
|
|
||||||
# Call function
|
# Call function
|
||||||
await db.add_agent_to_library("version123", "test-user")
|
await db.add_store_agent_to_library("version123", "test-user")
|
||||||
|
|
||||||
# Verify mocks called correctly
|
# Verify mocks called correctly
|
||||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||||
where={"id": "version123"}, include={"Agent": True}
|
where={"id": "version123"}, include={"Agent": True}
|
||||||
)
|
)
|
||||||
mock_user_agent.return_value.find_first.assert_called_once_with(
|
mock_library_agent.return_value.find_first.assert_called_once_with(
|
||||||
where={
|
where={
|
||||||
"userId": "test-user",
|
"userId": "test-user",
|
||||||
"agentId": "agent1",
|
"agentId": "agent1",
|
||||||
"agentVersion": 1,
|
"agentVersion": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
mock_user_agent.return_value.create.assert_called_once_with(
|
mock_library_agent.return_value.create.assert_called_once_with(
|
||||||
data=prisma.types.UserAgentCreateInput(
|
data=prisma.types.LibraryAgentCreateInput(
|
||||||
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -189,7 +166,7 @@ async def test_add_agent_to_library_not_found(mocker):
|
||||||
|
|
||||||
# Call function and verify exception
|
# Call function and verify exception
|
||||||
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
|
||||||
await db.add_agent_to_library("version123", "test-user")
|
await db.add_store_agent_to_library("version123", "test-user")
|
||||||
|
|
||||||
# Verify mock called correctly
|
# Verify mock called correctly
|
||||||
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
mock_store_listing_version.return_value.find_unique.assert_called_once_with(
|
||||||
|
|
|
@ -1,16 +1,112 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
import prisma.models
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
import backend.data.block
|
||||||
|
import backend.data.graph
|
||||||
|
import backend.server.model
|
||||||
|
|
||||||
|
|
||||||
class LibraryAgent(pydantic.BaseModel):
|
class LibraryAgent(pydantic.BaseModel):
|
||||||
id: str # Changed from agent_id to match GraphMeta
|
id: str # Changed from agent_id to match GraphMeta
|
||||||
version: int # Changed from agent_version to match GraphMeta
|
|
||||||
is_active: bool # Added to match GraphMeta
|
agent_id: str
|
||||||
|
agent_version: int # Changed from agent_version to match GraphMeta
|
||||||
|
|
||||||
|
preset_id: str | None
|
||||||
|
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
isCreatedByUser: bool
|
|
||||||
# Made input_schema and output_schema match GraphMeta's type
|
# Made input_schema and output_schema match GraphMeta's type
|
||||||
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
input_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||||
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
output_schema: dict[str, typing.Any] # Should be BlockIOObjectSubSchema in frontend
|
||||||
|
|
||||||
|
is_favorite: bool
|
||||||
|
is_created_by_user: bool
|
||||||
|
|
||||||
|
is_latest_version: bool
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_db(agent: prisma.models.LibraryAgent):
|
||||||
|
if not agent.Agent:
|
||||||
|
raise ValueError("AgentGraph is required")
|
||||||
|
|
||||||
|
graph = backend.data.graph.GraphModel.from_db(agent.Agent)
|
||||||
|
|
||||||
|
agent_updated_at = agent.Agent.updatedAt
|
||||||
|
lib_agent_updated_at = agent.updatedAt
|
||||||
|
|
||||||
|
# Take the latest updated_at timestamp either when the graph was updated or the library agent was updated
|
||||||
|
updated_at = (
|
||||||
|
max(agent_updated_at, lib_agent_updated_at)
|
||||||
|
if agent_updated_at
|
||||||
|
else lib_agent_updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
return LibraryAgent(
|
||||||
|
id=agent.id,
|
||||||
|
agent_id=agent.agentId,
|
||||||
|
agent_version=agent.agentVersion,
|
||||||
|
updated_at=updated_at,
|
||||||
|
name=graph.name,
|
||||||
|
description=graph.description,
|
||||||
|
input_schema=graph.input_schema,
|
||||||
|
output_schema=graph.output_schema,
|
||||||
|
is_favorite=agent.isFavorite,
|
||||||
|
is_created_by_user=agent.isCreatedByUser,
|
||||||
|
is_latest_version=graph.is_active,
|
||||||
|
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryAgentPreset(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
|
agent_id: str
|
||||||
|
agent_version: int
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
|
||||||
|
is_active: bool
|
||||||
|
|
||||||
|
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_db(preset: prisma.models.AgentPreset):
|
||||||
|
input_data = {}
|
||||||
|
|
||||||
|
for data in preset.InputPresets or []:
|
||||||
|
input_data[data.name] = json.loads(data.data)
|
||||||
|
|
||||||
|
return LibraryAgentPreset(
|
||||||
|
id=preset.id,
|
||||||
|
updated_at=preset.updatedAt,
|
||||||
|
agent_id=preset.agentId,
|
||||||
|
agent_version=preset.agentVersion,
|
||||||
|
name=preset.name,
|
||||||
|
description=preset.description,
|
||||||
|
is_active=preset.isActive,
|
||||||
|
inputs=input_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryAgentPresetResponse(pydantic.BaseModel):
|
||||||
|
presets: list[LibraryAgentPreset]
|
||||||
|
pagination: backend.server.model.Pagination
|
||||||
|
|
||||||
|
|
||||||
|
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
|
||||||
|
agent_id: str
|
||||||
|
agent_version: int
|
||||||
|
is_active: bool
|
||||||
|
|
|
@ -1,23 +1,35 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import prisma.models
|
||||||
|
|
||||||
|
import backend.data.block
|
||||||
|
import backend.server.model
|
||||||
import backend.server.v2.library.model
|
import backend.server.v2.library.model
|
||||||
|
|
||||||
|
|
||||||
def test_library_agent():
|
def test_library_agent():
|
||||||
agent = backend.server.v2.library.model.LibraryAgent(
|
agent = backend.server.v2.library.model.LibraryAgent(
|
||||||
id="test-agent-123",
|
id="test-agent-123",
|
||||||
version=1,
|
agent_id="agent-123",
|
||||||
is_active=True,
|
agent_version=1,
|
||||||
|
preset_id=None,
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
name="Test Agent",
|
name="Test Agent",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
isCreatedByUser=False,
|
|
||||||
input_schema={"type": "object", "properties": {}},
|
input_schema={"type": "object", "properties": {}},
|
||||||
output_schema={"type": "object", "properties": {}},
|
output_schema={"type": "object", "properties": {}},
|
||||||
|
is_favorite=False,
|
||||||
|
is_created_by_user=False,
|
||||||
|
is_latest_version=True,
|
||||||
)
|
)
|
||||||
assert agent.id == "test-agent-123"
|
assert agent.id == "test-agent-123"
|
||||||
assert agent.version == 1
|
assert agent.agent_id == "agent-123"
|
||||||
assert agent.is_active is True
|
assert agent.agent_version == 1
|
||||||
assert agent.name == "Test Agent"
|
assert agent.name == "Test Agent"
|
||||||
assert agent.description == "Test description"
|
assert agent.description == "Test description"
|
||||||
assert agent.isCreatedByUser is False
|
assert agent.is_favorite is False
|
||||||
|
assert agent.is_created_by_user is False
|
||||||
|
assert agent.is_latest_version is True
|
||||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
@ -25,19 +37,148 @@ def test_library_agent():
|
||||||
def test_library_agent_with_user_created():
|
def test_library_agent_with_user_created():
|
||||||
agent = backend.server.v2.library.model.LibraryAgent(
|
agent = backend.server.v2.library.model.LibraryAgent(
|
||||||
id="user-agent-456",
|
id="user-agent-456",
|
||||||
version=2,
|
agent_id="agent-456",
|
||||||
is_active=True,
|
agent_version=2,
|
||||||
|
preset_id=None,
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
name="User Created Agent",
|
name="User Created Agent",
|
||||||
description="An agent created by the user",
|
description="An agent created by the user",
|
||||||
isCreatedByUser=True,
|
|
||||||
input_schema={"type": "object", "properties": {}},
|
input_schema={"type": "object", "properties": {}},
|
||||||
output_schema={"type": "object", "properties": {}},
|
output_schema={"type": "object", "properties": {}},
|
||||||
|
is_favorite=False,
|
||||||
|
is_created_by_user=True,
|
||||||
|
is_latest_version=True,
|
||||||
)
|
)
|
||||||
assert agent.id == "user-agent-456"
|
assert agent.id == "user-agent-456"
|
||||||
assert agent.version == 2
|
assert agent.agent_id == "agent-456"
|
||||||
assert agent.is_active is True
|
assert agent.agent_version == 2
|
||||||
assert agent.name == "User Created Agent"
|
assert agent.name == "User Created Agent"
|
||||||
assert agent.description == "An agent created by the user"
|
assert agent.description == "An agent created by the user"
|
||||||
assert agent.isCreatedByUser is True
|
assert agent.is_favorite is False
|
||||||
|
assert agent.is_created_by_user is True
|
||||||
|
assert agent.is_latest_version is True
|
||||||
assert agent.input_schema == {"type": "object", "properties": {}}
|
assert agent.input_schema == {"type": "object", "properties": {}}
|
||||||
assert agent.output_schema == {"type": "object", "properties": {}}
|
assert agent.output_schema == {"type": "object", "properties": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_library_agent_preset():
|
||||||
|
preset = backend.server.v2.library.model.LibraryAgentPreset(
|
||||||
|
id="preset-123",
|
||||||
|
name="Test Preset",
|
||||||
|
description="Test preset description",
|
||||||
|
agent_id="test-agent-123",
|
||||||
|
agent_version=1,
|
||||||
|
is_active=True,
|
||||||
|
inputs={
|
||||||
|
"input1": backend.data.block.BlockInput(
|
||||||
|
name="input1",
|
||||||
|
data={"type": "string", "value": "test value"},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
assert preset.id == "preset-123"
|
||||||
|
assert preset.name == "Test Preset"
|
||||||
|
assert preset.description == "Test preset description"
|
||||||
|
assert preset.agent_id == "test-agent-123"
|
||||||
|
assert preset.agent_version == 1
|
||||||
|
assert preset.is_active is True
|
||||||
|
assert preset.inputs == {
|
||||||
|
"input1": backend.data.block.BlockInput(
|
||||||
|
name="input1", data={"type": "string", "value": "test value"}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_library_agent_preset_response():
|
||||||
|
preset = backend.server.v2.library.model.LibraryAgentPreset(
|
||||||
|
id="preset-123",
|
||||||
|
name="Test Preset",
|
||||||
|
description="Test preset description",
|
||||||
|
agent_id="test-agent-123",
|
||||||
|
agent_version=1,
|
||||||
|
is_active=True,
|
||||||
|
inputs={
|
||||||
|
"input1": backend.data.block.BlockInput(
|
||||||
|
name="input1",
|
||||||
|
data={"type": "string", "value": "test value"},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
updated_at=datetime.datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
pagination = backend.server.model.Pagination(
|
||||||
|
total_items=1, total_pages=1, current_page=1, page_size=10
|
||||||
|
)
|
||||||
|
|
||||||
|
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
|
||||||
|
presets=[preset], pagination=pagination
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(response.presets) == 1
|
||||||
|
assert response.presets[0].id == "preset-123"
|
||||||
|
assert response.pagination.total_items == 1
|
||||||
|
assert response.pagination.total_pages == 1
|
||||||
|
assert response.pagination.current_page == 1
|
||||||
|
assert response.pagination.page_size == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_library_agent_preset_request():
|
||||||
|
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||||
|
name="New Preset",
|
||||||
|
description="New preset description",
|
||||||
|
agent_id="agent-123",
|
||||||
|
agent_version=1,
|
||||||
|
is_active=True,
|
||||||
|
inputs={
|
||||||
|
"input1": backend.data.block.BlockInput(
|
||||||
|
name="input1",
|
||||||
|
data={"type": "string", "value": "test value"},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert request.name == "New Preset"
|
||||||
|
assert request.description == "New preset description"
|
||||||
|
assert request.agent_id == "agent-123"
|
||||||
|
assert request.agent_version == 1
|
||||||
|
assert request.is_active is True
|
||||||
|
assert request.inputs == {
|
||||||
|
"input1": backend.data.block.BlockInput(
|
||||||
|
name="input1", data={"type": "string", "value": "test value"}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_library_agent_from_db():
|
||||||
|
# Create mock DB agent
|
||||||
|
db_agent = prisma.models.AgentPreset(
|
||||||
|
id="test-agent-123",
|
||||||
|
createdAt=datetime.datetime.now(),
|
||||||
|
updatedAt=datetime.datetime.now(),
|
||||||
|
agentId="agent-123",
|
||||||
|
agentVersion=1,
|
||||||
|
name="Test Agent",
|
||||||
|
description="Test agent description",
|
||||||
|
isActive=True,
|
||||||
|
userId="test-user-123",
|
||||||
|
isDeleted=False,
|
||||||
|
InputPresets=[
|
||||||
|
prisma.models.AgentNodeExecutionInputOutput(
|
||||||
|
id="input-123",
|
||||||
|
time=datetime.datetime.now(),
|
||||||
|
name="input1",
|
||||||
|
data='{"type": "string", "value": "test value"}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to LibraryAgentPreset
|
||||||
|
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
|
||||||
|
|
||||||
|
assert agent.id == "test-agent-123"
|
||||||
|
assert agent.agent_version == 1
|
||||||
|
assert agent.is_active is True
|
||||||
|
assert agent.name == "Test Agent"
|
||||||
|
assert agent.description == "Test agent description"
|
||||||
|
assert agent.inputs == {"input1": {"type": "string", "value": "test value"}}
|
||||||
|
|
|
@ -1,123 +0,0 @@
|
||||||
import logging
|
|
||||||
import typing
|
|
||||||
|
|
||||||
import autogpt_libs.auth.depends
|
|
||||||
import autogpt_libs.auth.middleware
|
|
||||||
import fastapi
|
|
||||||
import prisma
|
|
||||||
|
|
||||||
import backend.data.graph
|
|
||||||
import backend.integrations.creds_manager
|
|
||||||
import backend.integrations.webhooks.graph_lifecycle_hooks
|
|
||||||
import backend.server.v2.library.db
|
|
||||||
import backend.server.v2.library.model
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = fastapi.APIRouter()
|
|
||||||
integration_creds_manager = (
|
|
||||||
backend.integrations.creds_manager.IntegrationCredentialsManager()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/agents",
|
|
||||||
tags=["library", "private"],
|
|
||||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
|
||||||
)
|
|
||||||
async def get_library_agents(
|
|
||||||
user_id: typing.Annotated[
|
|
||||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
|
||||||
]
|
|
||||||
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
|
||||||
"""
|
|
||||||
Get all agents in the user's library, including both created and saved agents.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
|
||||||
return agents
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Exception occurred whilst getting library agents")
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500, detail="Failed to get library agents"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/agents/{store_listing_version_id}",
|
|
||||||
tags=["library", "private"],
|
|
||||||
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
|
||||||
status_code=201,
|
|
||||||
)
|
|
||||||
async def add_agent_to_library(
|
|
||||||
store_listing_version_id: str,
|
|
||||||
user_id: typing.Annotated[
|
|
||||||
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
|
||||||
],
|
|
||||||
) -> fastapi.Response:
|
|
||||||
"""
|
|
||||||
Add an agent from the store to the user's library.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
store_listing_version_id (str): ID of the store listing version to add
|
|
||||||
user_id (str): ID of the authenticated user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
fastapi.Response: 201 status code on success
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If there is an error adding the agent to the library
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get the graph from the store listing
|
|
||||||
store_listing_version = (
|
|
||||||
await prisma.models.StoreListingVersion.prisma().find_unique(
|
|
||||||
where={"id": store_listing_version_id}, include={"Agent": True}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not store_listing_version or not store_listing_version.Agent:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"Store listing version {store_listing_version_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
agent = store_listing_version.Agent
|
|
||||||
|
|
||||||
if agent.userId == user_id:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400, detail="Cannot add own agent to library"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a new graph from the template
|
|
||||||
graph = await backend.data.graph.get_graph(
|
|
||||||
agent.id, agent.version, template=True, user_id=user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not graph:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=404, detail=f"Agent {agent.id} not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a deep copy with new IDs
|
|
||||||
graph.version = 1
|
|
||||||
graph.is_template = False
|
|
||||||
graph.is_active = True
|
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
|
||||||
|
|
||||||
# Save the new graph
|
|
||||||
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
|
|
||||||
graph = (
|
|
||||||
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
|
|
||||||
graph,
|
|
||||||
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return fastapi.Response(status_code=201)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Exception occurred whilst adding agent to library")
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500, detail="Failed to add agent to library"
|
|
||||||
)
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
import fastapi
|
||||||
|
|
||||||
|
from .agents import router as agents_router
|
||||||
|
from .presets import router as presets_router
|
||||||
|
|
||||||
|
router = fastapi.APIRouter()
|
||||||
|
|
||||||
|
router.include_router(presets_router)
|
||||||
|
router.include_router(agents_router)
|
|
@ -0,0 +1,148 @@
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import autogpt_libs.auth.depends
|
||||||
|
import autogpt_libs.auth.middleware
|
||||||
|
import autogpt_libs.utils.cache
|
||||||
|
import fastapi
|
||||||
|
|
||||||
|
import backend.server.v2.library.db
|
||||||
|
import backend.server.v2.library.model
|
||||||
|
import backend.server.v2.store.exceptions
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = fastapi.APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/agents",
|
||||||
|
tags=["library", "private"],
|
||||||
|
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||||
|
)
|
||||||
|
async def get_library_agents(
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
]
|
||||||
|
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
|
||||||
|
"""
|
||||||
|
Get all agents in the user's library, including both created and saved agents.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agents = await backend.server.v2.library.db.get_library_agents(user_id)
|
||||||
|
return agents
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst getting library agents: {e}")
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500, detail="Failed to get library agents"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/agents/{store_listing_version_id}",
|
||||||
|
tags=["library", "private"],
|
||||||
|
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||||
|
status_code=201,
|
||||||
|
)
|
||||||
|
async def add_agent_to_library(
|
||||||
|
store_listing_version_id: str,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
) -> fastapi.Response:
|
||||||
|
"""
|
||||||
|
Add an agent from the store to the user's library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_listing_version_id (str): ID of the store listing version to add
|
||||||
|
user_id (str): ID of the authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
fastapi.Response: 201 status code on success
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If there is an error adding the agent to the library
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use the database function to add the agent to the library
|
||||||
|
await backend.server.v2.library.db.add_store_agent_to_library(
|
||||||
|
store_listing_version_id, user_id
|
||||||
|
)
|
||||||
|
return fastapi.Response(status_code=201)
|
||||||
|
|
||||||
|
except backend.server.v2.store.exceptions.AgentNotFoundError:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
|
)
|
||||||
|
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||||
|
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500, detail="Failed to add agent to library"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception occurred whilst adding agent to library: {e}"
|
||||||
|
)
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500, detail="Failed to add agent to library"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"/agents/{library_agent_id}",
|
||||||
|
tags=["library", "private"],
|
||||||
|
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||||
|
status_code=204,
|
||||||
|
)
|
||||||
|
async def update_library_agent(
|
||||||
|
library_agent_id: str,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
auto_update_version: bool = False,
|
||||||
|
is_favorite: bool = False,
|
||||||
|
is_archived: bool = False,
|
||||||
|
is_deleted: bool = False,
|
||||||
|
) -> fastapi.Response:
|
||||||
|
"""
|
||||||
|
Update the library agent with the given fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
library_agent_id (str): ID of the library agent to update
|
||||||
|
user_id (str): ID of the authenticated user
|
||||||
|
auto_update_version (bool): Whether to auto-update the agent version
|
||||||
|
is_favorite (bool): Whether the agent is marked as favorite
|
||||||
|
is_archived (bool): Whether the agent is archived
|
||||||
|
is_deleted (bool): Whether the agent is deleted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
fastapi.Response: 204 status code on success
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If there is an error updating the library agent
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Use the database function to update the library agent
|
||||||
|
await backend.server.v2.library.db.update_library_agent(
|
||||||
|
library_agent_id,
|
||||||
|
user_id,
|
||||||
|
auto_update_version,
|
||||||
|
is_favorite,
|
||||||
|
is_archived,
|
||||||
|
is_deleted,
|
||||||
|
)
|
||||||
|
return fastapi.Response(status_code=204)
|
||||||
|
|
||||||
|
except backend.server.v2.store.exceptions.DatabaseError as e:
|
||||||
|
logger.exception(f"Database error occurred whilst updating library agent: {e}")
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500, detail="Failed to update library agent"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Unexpected exception occurred whilst updating library agent: {e}"
|
||||||
|
)
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=500, detail="Failed to update library agent"
|
||||||
|
)
|
|
@ -0,0 +1,156 @@
|
||||||
|
import logging
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import autogpt_libs.auth.depends
|
||||||
|
import autogpt_libs.auth.middleware
|
||||||
|
import autogpt_libs.utils.cache
|
||||||
|
import fastapi
|
||||||
|
|
||||||
|
import backend.data.graph
|
||||||
|
import backend.executor
|
||||||
|
import backend.integrations.creds_manager
|
||||||
|
import backend.integrations.webhooks.graph_lifecycle_hooks
|
||||||
|
import backend.server.v2.library.db
|
||||||
|
import backend.server.v2.library.model
|
||||||
|
import backend.util.service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = fastapi.APIRouter()
|
||||||
|
integration_creds_manager = (
|
||||||
|
backend.integrations.creds_manager.IntegrationCredentialsManager()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@autogpt_libs.utils.cache.thread_cached
|
||||||
|
def execution_manager_client() -> backend.executor.ExecutionManager:
|
||||||
|
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/presets")
|
||||||
|
async def get_presets(
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 10,
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
|
||||||
|
try:
|
||||||
|
presets = await backend.server.v2.library.db.get_presets(
|
||||||
|
user_id, page, page_size
|
||||||
|
)
|
||||||
|
return presets
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst getting presets: {e}")
|
||||||
|
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/presets/{preset_id}")
|
||||||
|
async def get_preset(
|
||||||
|
preset_id: str,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||||
|
try:
|
||||||
|
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
|
||||||
|
if not preset:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Preset {preset_id} not found",
|
||||||
|
)
|
||||||
|
return preset
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst getting preset: {e}")
|
||||||
|
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/presets")
|
||||||
|
async def create_preset(
|
||||||
|
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||||
|
try:
|
||||||
|
return await backend.server.v2.library.db.create_or_update_preset(
|
||||||
|
user_id, preset
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst creating preset: {e}")
|
||||||
|
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/presets/{preset_id}")
|
||||||
|
async def update_preset(
|
||||||
|
preset_id: str,
|
||||||
|
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
) -> backend.server.v2.library.model.LibraryAgentPreset:
|
||||||
|
try:
|
||||||
|
return await backend.server.v2.library.db.create_or_update_preset(
|
||||||
|
user_id, preset, preset_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst updating preset: {e}")
|
||||||
|
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/presets/{preset_id}")
|
||||||
|
async def delete_preset(
|
||||||
|
preset_id: str,
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
await backend.server.v2.library.db.delete_preset(user_id, preset_id)
|
||||||
|
return fastapi.Response(status_code=204)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Exception occurred whilst deleting preset: {e}")
|
||||||
|
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
path="/presets/{preset_id}/execute",
|
||||||
|
tags=["presets"],
|
||||||
|
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
|
||||||
|
)
|
||||||
|
async def execute_preset(
|
||||||
|
graph_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
preset_id: str,
|
||||||
|
node_input: dict[typing.Any, typing.Any],
|
||||||
|
user_id: typing.Annotated[
|
||||||
|
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
|
||||||
|
],
|
||||||
|
) -> dict[str, typing.Any]: # FIXME: add proper return type
|
||||||
|
try:
|
||||||
|
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
|
||||||
|
if not preset:
|
||||||
|
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
|
||||||
|
|
||||||
|
logger.info(f"Preset inputs: {preset.inputs}")
|
||||||
|
|
||||||
|
updated_node_input = node_input.copy()
|
||||||
|
# Merge in preset input values
|
||||||
|
for key, value in preset.inputs.items():
|
||||||
|
if key not in updated_node_input:
|
||||||
|
updated_node_input[key] = value
|
||||||
|
|
||||||
|
execution = execution_manager_client().add_execution(
|
||||||
|
graph_id=graph_id,
|
||||||
|
graph_version=graph_version,
|
||||||
|
data=updated_node_input,
|
||||||
|
user_id=user_id,
|
||||||
|
preset_id=preset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
|
||||||
|
|
||||||
|
return {"id": execution.graph_exec_id}
|
||||||
|
except Exception as e:
|
||||||
|
msg = e.__str__().encode().decode("unicode_escape")
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail=msg)
|
|
@ -1,3 +1,5 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
import autogpt_libs.auth.depends
|
import autogpt_libs.auth.depends
|
||||||
import autogpt_libs.auth.middleware
|
import autogpt_libs.auth.middleware
|
||||||
import fastapi
|
import fastapi
|
||||||
|
@ -35,21 +37,29 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||||
mocked_value = [
|
mocked_value = [
|
||||||
backend.server.v2.library.model.LibraryAgent(
|
backend.server.v2.library.model.LibraryAgent(
|
||||||
id="test-agent-1",
|
id="test-agent-1",
|
||||||
version=1,
|
agent_id="test-agent-1",
|
||||||
is_active=True,
|
agent_version=1,
|
||||||
|
preset_id="preset-1",
|
||||||
|
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||||
|
is_favorite=False,
|
||||||
|
is_created_by_user=True,
|
||||||
|
is_latest_version=True,
|
||||||
name="Test Agent 1",
|
name="Test Agent 1",
|
||||||
description="Test Description 1",
|
description="Test Description 1",
|
||||||
isCreatedByUser=True,
|
|
||||||
input_schema={"type": "object", "properties": {}},
|
input_schema={"type": "object", "properties": {}},
|
||||||
output_schema={"type": "object", "properties": {}},
|
output_schema={"type": "object", "properties": {}},
|
||||||
),
|
),
|
||||||
backend.server.v2.library.model.LibraryAgent(
|
backend.server.v2.library.model.LibraryAgent(
|
||||||
id="test-agent-2",
|
id="test-agent-2",
|
||||||
version=1,
|
agent_id="test-agent-2",
|
||||||
is_active=True,
|
agent_version=1,
|
||||||
|
preset_id="preset-2",
|
||||||
|
updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||||
|
is_favorite=False,
|
||||||
|
is_created_by_user=False,
|
||||||
|
is_latest_version=True,
|
||||||
name="Test Agent 2",
|
name="Test Agent 2",
|
||||||
description="Test Description 2",
|
description="Test Description 2",
|
||||||
isCreatedByUser=False,
|
|
||||||
input_schema={"type": "object", "properties": {}},
|
input_schema={"type": "object", "properties": {}},
|
||||||
output_schema={"type": "object", "properties": {}},
|
output_schema={"type": "object", "properties": {}},
|
||||||
),
|
),
|
||||||
|
@ -65,10 +75,10 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
|
||||||
for agent in response.json()
|
for agent in response.json()
|
||||||
]
|
]
|
||||||
assert len(data) == 2
|
assert len(data) == 2
|
||||||
assert data[0].id == "test-agent-1"
|
assert data[0].agent_id == "test-agent-1"
|
||||||
assert data[0].isCreatedByUser is True
|
assert data[0].is_created_by_user is True
|
||||||
assert data[1].id == "test-agent-2"
|
assert data[1].agent_id == "test-agent-2"
|
||||||
assert data[1].isCreatedByUser is False
|
assert data[1].is_created_by_user is False
|
||||||
mock_db_call.assert_called_once_with("test-user-id")
|
mock_db_call.assert_called_once_with("test-user-id")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -325,7 +325,10 @@ async def get_store_submissions(
|
||||||
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
|
where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
|
||||||
# Query submissions from database
|
# Query submissions from database
|
||||||
submissions = await prisma.models.StoreSubmission.prisma().find_many(
|
submissions = await prisma.models.StoreSubmission.prisma().find_many(
|
||||||
where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
|
where=where,
|
||||||
|
skip=skip,
|
||||||
|
take=page_size,
|
||||||
|
order=[{"date_submitted": "desc"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get total count for pagination
|
# Get total count for pagination
|
||||||
|
@ -405,9 +408,7 @@ async def delete_store_submission(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete the submission
|
# Delete the submission
|
||||||
await prisma.models.StoreListing.prisma().delete(
|
await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
|
||||||
where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Successfully deleted submission {submission_id} for user {user_id}"
|
f"Successfully deleted submission {submission_id} for user {user_id}"
|
||||||
|
@ -504,7 +505,15 @@ async def create_store_submission(
|
||||||
"subHeading": sub_heading,
|
"subHeading": sub_heading,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
|
include={"StoreListingVersions": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
slv_id = (
|
||||||
|
listing.StoreListingVersions[0].id
|
||||||
|
if listing.StoreListingVersions is not None
|
||||||
|
and len(listing.StoreListingVersions) > 0
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Created store listing for agent {agent_id}")
|
logger.debug(f"Created store listing for agent {agent_id}")
|
||||||
|
@ -521,6 +530,7 @@ async def create_store_submission(
|
||||||
status=prisma.enums.SubmissionStatus.PENDING,
|
status=prisma.enums.SubmissionStatus.PENDING,
|
||||||
runs=0,
|
runs=0,
|
||||||
rating=0.0,
|
rating=0.0,
|
||||||
|
store_listing_version_id=slv_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except (
|
except (
|
||||||
|
@ -811,9 +821,7 @@ async def get_agent(
|
||||||
|
|
||||||
agent = store_listing_version.Agent
|
agent = store_listing_version.Agent
|
||||||
|
|
||||||
graph = await backend.data.graph.get_graph(
|
graph = await backend.data.graph.get_graph(agent.id, agent.version)
|
||||||
agent.id, agent.version, template=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not graph:
|
if not graph:
|
||||||
raise fastapi.HTTPException(
|
raise fastapi.HTTPException(
|
||||||
|
@ -832,3 +840,74 @@ async def get_agent(
|
||||||
raise backend.server.v2.store.exceptions.DatabaseError(
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
"Failed to fetch agent"
|
"Failed to fetch agent"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def review_store_submission(
|
||||||
|
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
|
||||||
|
) -> prisma.models.StoreListingSubmission:
|
||||||
|
"""Review a store listing submission."""
|
||||||
|
try:
|
||||||
|
store_listing_version = (
|
||||||
|
await prisma.models.StoreListingVersion.prisma().find_unique(
|
||||||
|
where={"id": store_listing_version_id},
|
||||||
|
include={"StoreListing": True},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not store_listing_version or not store_listing_version.StoreListing:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Store listing version {store_listing_version_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
status = (
|
||||||
|
prisma.enums.SubmissionStatus.APPROVED
|
||||||
|
if is_approved
|
||||||
|
else prisma.enums.SubmissionStatus.REJECTED
|
||||||
|
)
|
||||||
|
|
||||||
|
create_data = prisma.types.StoreListingSubmissionCreateInput(
|
||||||
|
StoreListingVersion={"connect": {"id": store_listing_version_id}},
|
||||||
|
Status=status,
|
||||||
|
reviewComments=comments,
|
||||||
|
Reviewer={"connect": {"id": reviewer_id}},
|
||||||
|
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
|
||||||
|
createdAt=datetime.now(),
|
||||||
|
updatedAt=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = prisma.types.StoreListingSubmissionUpdateInput(
|
||||||
|
Status=status,
|
||||||
|
reviewComments=comments,
|
||||||
|
Reviewer={"connect": {"id": reviewer_id}},
|
||||||
|
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
|
||||||
|
updatedAt=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_approved:
|
||||||
|
await prisma.models.StoreListing.prisma().update(
|
||||||
|
where={"id": store_listing_version.StoreListing.id},
|
||||||
|
data={"isApproved": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
|
||||||
|
where={"storeListingVersionId": store_listing_version_id},
|
||||||
|
data=prisma.types.StoreListingSubmissionUpsertInput(
|
||||||
|
create=create_data,
|
||||||
|
update=update_data,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not submission:
|
||||||
|
raise fastapi.HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Store listing submission {store_listing_version_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return submission
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reviewing store submission: {str(e)}")
|
||||||
|
raise backend.server.v2.store.exceptions.DatabaseError(
|
||||||
|
"Failed to review store submission"
|
||||||
|
) from e
|
||||||
|
|
|
@ -115,6 +115,7 @@ class StoreSubmission(pydantic.BaseModel):
|
||||||
status: prisma.enums.SubmissionStatus
|
status: prisma.enums.SubmissionStatus
|
||||||
runs: int
|
runs: int
|
||||||
rating: float
|
rating: float
|
||||||
|
store_listing_version_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class StoreSubmissionsResponse(pydantic.BaseModel):
|
class StoreSubmissionsResponse(pydantic.BaseModel):
|
||||||
|
@ -151,3 +152,9 @@ class StoreReviewCreate(pydantic.BaseModel):
|
||||||
store_listing_version_id: str
|
store_listing_version_id: str
|
||||||
score: int
|
score: int
|
||||||
comments: str | None = None
|
comments: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewSubmissionRequest(pydantic.BaseModel):
|
||||||
|
store_listing_version_id: str
|
||||||
|
isApproved: bool
|
||||||
|
comments: str
|
||||||
|
|
|
@ -642,3 +642,33 @@ async def download_agent_file(
|
||||||
return fastapi.responses.FileResponse(
|
return fastapi.responses.FileResponse(
|
||||||
tmp_file.name, filename=file_name, media_type="application/json"
|
tmp_file.name, filename=file_name, media_type="application/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/submissions/review/{store_listing_version_id}",
|
||||||
|
tags=["store", "private"],
|
||||||
|
)
|
||||||
|
async def review_submission(
|
||||||
|
request: backend.server.v2.store.model.ReviewSubmissionRequest,
|
||||||
|
user: typing.Annotated[
|
||||||
|
autogpt_libs.auth.models.User,
|
||||||
|
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
|
||||||
|
],
|
||||||
|
):
|
||||||
|
# Proceed with the review submission logic
|
||||||
|
try:
|
||||||
|
submission = await backend.server.v2.store.db.review_store_submission(
|
||||||
|
store_listing_version_id=request.store_listing_version_id,
|
||||||
|
is_approved=request.isApproved,
|
||||||
|
comments=request.comments,
|
||||||
|
reviewer_id=user.user_id,
|
||||||
|
)
|
||||||
|
return submission
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Exception occurred whilst reviewing store submission")
|
||||||
|
return fastapi.responses.JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={
|
||||||
|
"detail": "An error occurred while reviewing the store submission"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -253,7 +253,7 @@ async def block_autogen_agent():
|
||||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||||
input_data = {"input": "Write me a block that writes a string into a file."}
|
input_data = {"input": "Write me a block that writes a string into a file."}
|
||||||
response = await server.agent_server.test_execute_graph(
|
response = await server.agent_server.test_execute_graph(
|
||||||
test_graph.id, input_data, test_user.id
|
test_graph.id, test_graph.version, input_data, test_user.id
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
result = await wait_execution(
|
result = await wait_execution(
|
||||||
|
|
|
@ -157,7 +157,7 @@ async def reddit_marketing_agent():
|
||||||
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
|
||||||
input_data = {"subreddit": "AutoGPT"}
|
input_data = {"subreddit": "AutoGPT"}
|
||||||
response = await server.agent_server.test_execute_graph(
|
response = await server.agent_server.test_execute_graph(
|
||||||
test_graph.id, input_data, test_user.id
|
test_graph.id, test_graph.version, input_data, test_user.id
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
|
result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
|
||||||
|
|
|
@ -8,12 +8,19 @@ from backend.data.user import get_or_create_user
|
||||||
from backend.util.test import SpinTestServer, wait_execution
|
from backend.util.test import SpinTestServer, wait_execution
|
||||||
|
|
||||||
|
|
||||||
async def create_test_user() -> User:
|
async def create_test_user(alt_user: bool = False) -> User:
|
||||||
test_user_data = {
|
if alt_user:
|
||||||
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
test_user_data = {
|
||||||
"email": "testuser#example.com",
|
"sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||||
"name": "Test User",
|
"email": "testuser2#example.com",
|
||||||
}
|
"name": "Test User 2",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
test_user_data = {
|
||||||
|
"sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
|
||||||
|
"email": "testuser#example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
user = await get_or_create_user(test_user_data)
|
user = await get_or_create_user(test_user_data)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
@ -79,7 +86,7 @@ async def sample_agent():
|
||||||
test_graph = await create_graph(create_test_graph(), test_user.id)
|
test_graph = await create_graph(create_test_graph(), test_user.id)
|
||||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||||
response = await server.agent_server.test_execute_graph(
|
response = await server.agent_server.test_execute_graph(
|
||||||
test_graph.id, input_data, test_user.id
|
test_graph.id, test_graph.version, input_data, test_user.id
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
|
result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
|
||||||
|
|
|
@ -153,6 +153,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||||
description="The name of the Google Cloud Storage bucket for media files",
|
description="The name of the Google Cloud Storage bucket for media files",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
reddit_user_agent: str = Field(
|
||||||
|
default="AutoGPT:1.0 (by /u/autogpt)",
|
||||||
|
description="The user agent for the Reddit API",
|
||||||
|
)
|
||||||
|
|
||||||
scheduler_db_pool_size: int = Field(
|
scheduler_db_pool_size: int = Field(
|
||||||
default=3,
|
default=3,
|
||||||
description="The pool size for the scheduler database connection pool",
|
description="The pool size for the scheduler database connection pool",
|
||||||
|
@ -276,8 +281,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||||
|
|
||||||
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
reddit_client_id: str = Field(default="", description="Reddit client ID")
|
||||||
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
reddit_client_secret: str = Field(default="", description="Reddit client secret")
|
||||||
reddit_username: str = Field(default="", description="Reddit username")
|
|
||||||
reddit_password: str = Field(default="", description="Reddit password")
|
|
||||||
|
|
||||||
openweathermap_api_key: str = Field(
|
openweathermap_api_key: str = Field(
|
||||||
default="", description="OpenWeatherMap API key"
|
default="", description="OpenWeatherMap API key"
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "AgentPreset" ADD COLUMN "isDeleted" BOOLEAN NOT NULL DEFAULT false;
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
|
||||||
|
|
||||||
|
*/
|
||||||
|
-- DropForeignKey
|
||||||
|
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
|
||||||
|
|
||||||
|
-- DropForeignKey
|
||||||
|
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
|
||||||
|
|
||||||
|
-- DropForeignKey
|
||||||
|
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
|
||||||
|
|
||||||
|
-- DropTable
|
||||||
|
DROP TABLE "UserAgent";
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "LibraryAgent" (
|
||||||
|
"id" TEXT NOT NULL,
|
||||||
|
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"userId" TEXT NOT NULL,
|
||||||
|
"agentId" TEXT NOT NULL,
|
||||||
|
"agentVersion" INTEGER NOT NULL,
|
||||||
|
"agentPresetId" TEXT,
|
||||||
|
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"isArchived" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
|
||||||
|
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||||
|
|
||||||
|
-- AddForeignKey
|
||||||
|
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
|
@ -0,0 +1,2 @@
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;
|
|
@ -0,0 +1,29 @@
|
||||||
|
/*
|
||||||
|
Warnings:
|
||||||
|
|
||||||
|
- A unique constraint covering the columns `[agentId]` on the table `StoreListing` will be added. If there are existing duplicate values, this will fail.
|
||||||
|
|
||||||
|
*/
|
||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListing_agentId_idx";
|
||||||
|
|
||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListing_isApproved_idx";
|
||||||
|
|
||||||
|
-- DropIndex
|
||||||
|
DROP INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx";
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "StoreListing_agentId_owningUserId_idx" ON "StoreListing"("agentId", "owningUserId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "StoreListing_isDeleted_isApproved_idx" ON "StoreListing"("isDeleted", "isApproved");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "StoreListing_isDeleted_idx" ON "StoreListing"("isDeleted");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "StoreListing_agentId_key" ON "StoreListing"("agentId");
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isDeleted_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isDeleted");
|
|
@ -30,7 +30,7 @@ model User {
|
||||||
CreditTransaction CreditTransaction[]
|
CreditTransaction CreditTransaction[]
|
||||||
|
|
||||||
AgentPreset AgentPreset[]
|
AgentPreset AgentPreset[]
|
||||||
UserAgent UserAgent[]
|
LibraryAgent LibraryAgent[]
|
||||||
|
|
||||||
Profile Profile[]
|
Profile Profile[]
|
||||||
StoreListing StoreListing[]
|
StoreListing StoreListing[]
|
||||||
|
@ -65,7 +65,7 @@ model AgentGraph {
|
||||||
AgentGraphExecution AgentGraphExecution[]
|
AgentGraphExecution AgentGraphExecution[]
|
||||||
|
|
||||||
AgentPreset AgentPreset[]
|
AgentPreset AgentPreset[]
|
||||||
UserAgent UserAgent[]
|
LibraryAgent LibraryAgent[]
|
||||||
StoreListing StoreListing[]
|
StoreListing StoreListing[]
|
||||||
StoreListingVersion StoreListingVersion?
|
StoreListingVersion StoreListingVersion?
|
||||||
|
|
||||||
|
@ -103,15 +103,17 @@ model AgentPreset {
|
||||||
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
Agent AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
|
||||||
|
|
||||||
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
InputPresets AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
|
||||||
UserAgents UserAgent[]
|
LibraryAgents LibraryAgent[]
|
||||||
AgentExecution AgentGraphExecution[]
|
AgentExecution AgentGraphExecution[]
|
||||||
|
|
||||||
|
isDeleted Boolean @default(false)
|
||||||
|
|
||||||
@@index([userId])
|
@@index([userId])
|
||||||
}
|
}
|
||||||
|
|
||||||
// For the library page
|
// For the library page
|
||||||
// It is a user controlled list of agents, that they will see in there library
|
// It is a user controlled list of agents, that they will see in there library
|
||||||
model UserAgent {
|
model LibraryAgent {
|
||||||
id String @id @default(uuid())
|
id String @id @default(uuid())
|
||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
updatedAt DateTime @default(now()) @updatedAt
|
updatedAt DateTime @default(now()) @updatedAt
|
||||||
|
@ -126,6 +128,8 @@ model UserAgent {
|
||||||
agentPresetId String?
|
agentPresetId String?
|
||||||
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
AgentPreset AgentPreset? @relation(fields: [agentPresetId], references: [id])
|
||||||
|
|
||||||
|
useGraphIsActiveVersion Boolean @default(false)
|
||||||
|
|
||||||
isFavorite Boolean @default(false)
|
isFavorite Boolean @default(false)
|
||||||
isCreatedByUser Boolean @default(false)
|
isCreatedByUser Boolean @default(false)
|
||||||
isArchived Boolean @default(false)
|
isArchived Boolean @default(false)
|
||||||
|
@ -235,7 +239,7 @@ model AgentGraphExecution {
|
||||||
|
|
||||||
AgentNodeExecutions AgentNodeExecution[]
|
AgentNodeExecutions AgentNodeExecution[]
|
||||||
|
|
||||||
// Link to User model
|
// Link to User model -- Executed by this user
|
||||||
userId String
|
userId String
|
||||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||||
|
|
||||||
|
@ -443,6 +447,8 @@ view Creator {
|
||||||
agent_rating Float
|
agent_rating Float
|
||||||
agent_runs Int
|
agent_runs Int
|
||||||
is_featured Boolean
|
is_featured Boolean
|
||||||
|
|
||||||
|
// Index or unique are not applied to views
|
||||||
}
|
}
|
||||||
|
|
||||||
view StoreAgent {
|
view StoreAgent {
|
||||||
|
@ -465,11 +471,7 @@ view StoreAgent {
|
||||||
rating Float
|
rating Float
|
||||||
versions String[]
|
versions String[]
|
||||||
|
|
||||||
@@unique([creator_username, slug])
|
// Index or unique are not applied to views
|
||||||
@@index([creator_username])
|
|
||||||
@@index([featured])
|
|
||||||
@@index([categories])
|
|
||||||
@@index([storeListingVersionId])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
view StoreSubmission {
|
view StoreSubmission {
|
||||||
|
@ -487,7 +489,7 @@ view StoreSubmission {
|
||||||
agent_id String
|
agent_id String
|
||||||
agent_version Int
|
agent_version Int
|
||||||
|
|
||||||
@@index([user_id])
|
// Index or unique are not applied to views
|
||||||
}
|
}
|
||||||
|
|
||||||
model StoreListing {
|
model StoreListing {
|
||||||
|
@ -510,9 +512,13 @@ model StoreListing {
|
||||||
StoreListingVersions StoreListingVersion[]
|
StoreListingVersions StoreListingVersion[]
|
||||||
StoreListingSubmission StoreListingSubmission[]
|
StoreListingSubmission StoreListingSubmission[]
|
||||||
|
|
||||||
@@index([isApproved])
|
// Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
|
||||||
@@index([agentId])
|
@@unique([agentId])
|
||||||
|
@@index([agentId, owningUserId])
|
||||||
@@index([owningUserId])
|
@@index([owningUserId])
|
||||||
|
// Used in the view query
|
||||||
|
@@index([isDeleted, isApproved])
|
||||||
|
@@index([isDeleted])
|
||||||
}
|
}
|
||||||
|
|
||||||
model StoreListingVersion {
|
model StoreListingVersion {
|
||||||
|
@ -553,7 +559,7 @@ model StoreListingVersion {
|
||||||
StoreListingReview StoreListingReview[]
|
StoreListingReview StoreListingReview[]
|
||||||
|
|
||||||
@@unique([agentId, agentVersion])
|
@@unique([agentId, agentVersion])
|
||||||
@@index([agentId, agentVersion, isApproved])
|
@@index([agentId, agentVersion, isDeleted])
|
||||||
}
|
}
|
||||||
|
|
||||||
model StoreListingReview {
|
model StoreListingReview {
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import autogpt_libs.auth.models
|
||||||
|
import fastapi.exceptions
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import backend.server.v2.store.model
|
||||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
||||||
from backend.data.block import BlockSchema
|
from backend.data.block import BlockSchema
|
||||||
from backend.data.graph import Graph, Link, Node
|
from backend.data.graph import Graph, Link, Node
|
||||||
|
@ -202,3 +205,92 @@ async def test_clean_graph(server: SpinTestServer):
|
||||||
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
|
n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
|
||||||
)
|
)
|
||||||
assert input_node.input_default["value"] == ""
|
assert input_node.input_default["value"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(scope="session")
|
||||||
|
async def test_access_store_listing_graph(server: SpinTestServer):
|
||||||
|
"""
|
||||||
|
Test the access of a store listing graph.
|
||||||
|
"""
|
||||||
|
graph = Graph(
|
||||||
|
id="test_clean_graph",
|
||||||
|
name="Test Clean Graph",
|
||||||
|
description="Test graph cleaning",
|
||||||
|
nodes=[
|
||||||
|
Node(
|
||||||
|
id="input_node",
|
||||||
|
block_id=AgentInputBlock().id,
|
||||||
|
input_default={
|
||||||
|
"name": "test_input",
|
||||||
|
"value": "test value",
|
||||||
|
"description": "Test input description",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
links=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create graph and get model
|
||||||
|
create_graph = CreateGraph(graph=graph)
|
||||||
|
created_graph = await server.agent_server.test_create_graph(
|
||||||
|
create_graph, DEFAULT_USER_ID
|
||||||
|
)
|
||||||
|
|
||||||
|
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||||
|
agent_id=created_graph.id,
|
||||||
|
agent_version=created_graph.version,
|
||||||
|
slug="test-slug",
|
||||||
|
name="Test name",
|
||||||
|
sub_heading="Test sub heading",
|
||||||
|
video_url=None,
|
||||||
|
image_urls=[],
|
||||||
|
description="Test description",
|
||||||
|
categories=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# First we check the graph an not be accessed by a different user
|
||||||
|
with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
|
||||||
|
await server.agent_server.test_get_graph(
|
||||||
|
created_graph.id,
|
||||||
|
created_graph.version,
|
||||||
|
"3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||||
|
)
|
||||||
|
assert exc_info.value.status_code == 404
|
||||||
|
assert "Graph" in str(exc_info.value.detail)
|
||||||
|
|
||||||
|
# Now we create a store listing
|
||||||
|
store_listing = await server.agent_server.test_create_store_listing(
|
||||||
|
store_submission_request, DEFAULT_USER_ID
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(store_listing, fastapi.responses.JSONResponse):
|
||||||
|
assert False, "Failed to create store listing"
|
||||||
|
|
||||||
|
slv_id = (
|
||||||
|
store_listing.store_listing_version_id
|
||||||
|
if store_listing.store_listing_version_id is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert slv_id is not None
|
||||||
|
|
||||||
|
admin = autogpt_libs.auth.models.User(
|
||||||
|
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||||
|
role="admin",
|
||||||
|
email="admin@example.com",
|
||||||
|
phone_number="1234567890",
|
||||||
|
)
|
||||||
|
await server.agent_server.test_review_store_listing(
|
||||||
|
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||||
|
store_listing_version_id=slv_id,
|
||||||
|
isApproved=True,
|
||||||
|
comments="Test comments",
|
||||||
|
),
|
||||||
|
admin,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now we check the graph can be accessed by a user that does not own the graph
|
||||||
|
got_graph = await server.agent_server.test_get_graph(
|
||||||
|
created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
|
||||||
|
)
|
||||||
|
assert got_graph is not None
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import autogpt_libs.auth.models
|
||||||
|
import fastapi.responses
|
||||||
import pytest
|
import pytest
|
||||||
from prisma.models import User
|
from prisma.models import User
|
||||||
|
|
||||||
from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
|
import backend.server.v2.library.model
|
||||||
|
import backend.server.v2.store.model
|
||||||
|
from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
|
||||||
from backend.blocks.maths import CalculatorBlock, Operation
|
from backend.blocks.maths import CalculatorBlock, Operation
|
||||||
from backend.data import execution, graph
|
from backend.data import execution, graph
|
||||||
from backend.server.model import CreateGraph
|
from backend.server.model import CreateGraph
|
||||||
|
@ -31,7 +35,7 @@ async def execute_graph(
|
||||||
|
|
||||||
# --- Test adding new executions --- #
|
# --- Test adding new executions --- #
|
||||||
response = await agent_server.test_execute_graph(
|
response = await agent_server.test_execute_graph(
|
||||||
test_graph.id, input_data, test_user.id
|
test_graph.id, test_graph.version, input_data, test_user.id
|
||||||
)
|
)
|
||||||
graph_exec_id = response["id"]
|
graph_exec_id = response["id"]
|
||||||
logger.info(f"Created execution with ID: {graph_exec_id}")
|
logger.info(f"Created execution with ID: {graph_exec_id}")
|
||||||
|
@ -287,3 +291,255 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
||||||
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
assert exec_data.status == execution.ExecutionStatus.COMPLETED
|
||||||
assert exec_data.output_data == {"result": [9]}
|
assert exec_data.output_data == {"result": [9]}
|
||||||
logger.info("Completed test_static_input_link_on_graph")
|
logger.info("Completed test_static_input_link_on_graph")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(scope="session")
|
||||||
|
async def test_execute_preset(server: SpinTestServer):
|
||||||
|
"""
|
||||||
|
Test executing a preset.
|
||||||
|
|
||||||
|
This test ensures that:
|
||||||
|
1. A preset can be successfully executed
|
||||||
|
2. The execution results are correct
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server (SpinTestServer): The test server instance.
|
||||||
|
"""
|
||||||
|
# Create test graph and user
|
||||||
|
nodes = [
|
||||||
|
graph.Node( # 0
|
||||||
|
block_id=AgentInputBlock().id,
|
||||||
|
input_default={"name": "dictionary"},
|
||||||
|
),
|
||||||
|
graph.Node( # 1
|
||||||
|
block_id=AgentInputBlock().id,
|
||||||
|
input_default={"name": "selected_value"},
|
||||||
|
),
|
||||||
|
graph.Node( # 2
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||||
|
),
|
||||||
|
graph.Node( # 3
|
||||||
|
block_id=FindInDictionaryBlock().id,
|
||||||
|
input_default={"key": "", "input": {}},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
links = [
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[0].id,
|
||||||
|
sink_id=nodes[2].id,
|
||||||
|
source_name="result",
|
||||||
|
sink_name="input",
|
||||||
|
),
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[1].id,
|
||||||
|
sink_id=nodes[3].id,
|
||||||
|
source_name="result",
|
||||||
|
sink_name="key",
|
||||||
|
),
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[2].id,
|
||||||
|
sink_id=nodes[3].id,
|
||||||
|
source_name="output",
|
||||||
|
sink_name="input",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
test_graph = graph.Graph(
|
||||||
|
name="TestGraph",
|
||||||
|
description="Test graph",
|
||||||
|
nodes=nodes,
|
||||||
|
links=links,
|
||||||
|
)
|
||||||
|
test_user = await create_test_user()
|
||||||
|
test_graph = await create_graph(server, test_graph, test_user)
|
||||||
|
|
||||||
|
# Create preset with initial values
|
||||||
|
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||||
|
name="Test Preset With Clash",
|
||||||
|
description="Test preset with clashing input values",
|
||||||
|
agent_id=test_graph.id,
|
||||||
|
agent_version=test_graph.version,
|
||||||
|
inputs={
|
||||||
|
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||||
|
"selected_value": "key2",
|
||||||
|
},
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||||
|
|
||||||
|
# Execute preset with overriding values
|
||||||
|
result = await server.agent_server.test_execute_preset(
|
||||||
|
graph_id=test_graph.id,
|
||||||
|
graph_version=test_graph.version,
|
||||||
|
preset_id=created_preset.id,
|
||||||
|
node_input={},
|
||||||
|
user_id=test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify execution
|
||||||
|
assert result is not None
|
||||||
|
graph_exec_id = result["id"]
|
||||||
|
|
||||||
|
# Wait for execution to complete
|
||||||
|
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||||
|
assert len(executions) == 4
|
||||||
|
|
||||||
|
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||||
|
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||||
|
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||||
|
assert executions[3].output_data == {"output": ["World"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(scope="session")
|
||||||
|
async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||||
|
"""
|
||||||
|
Test executing a preset with clashing input data.
|
||||||
|
"""
|
||||||
|
# Create test graph and user
|
||||||
|
nodes = [
|
||||||
|
graph.Node( # 0
|
||||||
|
block_id=AgentInputBlock().id,
|
||||||
|
input_default={"name": "dictionary"},
|
||||||
|
),
|
||||||
|
graph.Node( # 1
|
||||||
|
block_id=AgentInputBlock().id,
|
||||||
|
input_default={"name": "selected_value"},
|
||||||
|
),
|
||||||
|
graph.Node( # 2
|
||||||
|
block_id=StoreValueBlock().id,
|
||||||
|
input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
|
||||||
|
),
|
||||||
|
graph.Node( # 3
|
||||||
|
block_id=FindInDictionaryBlock().id,
|
||||||
|
input_default={"key": "", "input": {}},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
links = [
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[0].id,
|
||||||
|
sink_id=nodes[2].id,
|
||||||
|
source_name="result",
|
||||||
|
sink_name="input",
|
||||||
|
),
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[1].id,
|
||||||
|
sink_id=nodes[3].id,
|
||||||
|
source_name="result",
|
||||||
|
sink_name="key",
|
||||||
|
),
|
||||||
|
graph.Link(
|
||||||
|
source_id=nodes[2].id,
|
||||||
|
sink_id=nodes[3].id,
|
||||||
|
source_name="output",
|
||||||
|
sink_name="input",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
test_graph = graph.Graph(
|
||||||
|
name="TestGraph",
|
||||||
|
description="Test graph",
|
||||||
|
nodes=nodes,
|
||||||
|
links=links,
|
||||||
|
)
|
||||||
|
test_user = await create_test_user()
|
||||||
|
test_graph = await create_graph(server, test_graph, test_user)
|
||||||
|
|
||||||
|
# Create preset with initial values
|
||||||
|
preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
|
||||||
|
name="Test Preset With Clash",
|
||||||
|
description="Test preset with clashing input values",
|
||||||
|
agent_id=test_graph.id,
|
||||||
|
agent_version=test_graph.version,
|
||||||
|
inputs={
|
||||||
|
"dictionary": {"key1": "Hello", "key2": "World"},
|
||||||
|
"selected_value": "key2",
|
||||||
|
},
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
created_preset = await server.agent_server.test_create_preset(preset, test_user.id)
|
||||||
|
|
||||||
|
# Execute preset with overriding values
|
||||||
|
result = await server.agent_server.test_execute_preset(
|
||||||
|
graph_id=test_graph.id,
|
||||||
|
graph_version=test_graph.version,
|
||||||
|
preset_id=created_preset.id,
|
||||||
|
node_input={"selected_value": "key1"},
|
||||||
|
user_id=test_user.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify execution
|
||||||
|
assert result is not None
|
||||||
|
graph_exec_id = result["id"]
|
||||||
|
|
||||||
|
# Wait for execution to complete
|
||||||
|
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||||
|
assert len(executions) == 4
|
||||||
|
|
||||||
|
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||||
|
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||||
|
assert executions[3].status == execution.ExecutionStatus.COMPLETED
|
||||||
|
assert executions[3].output_data == {"output": ["Hello"]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(scope="session")
|
||||||
|
async def test_store_listing_graph(server: SpinTestServer):
|
||||||
|
logger.info("Starting test_agent_execution")
|
||||||
|
test_user = await create_test_user()
|
||||||
|
test_graph = await create_graph(server, create_test_graph(), test_user)
|
||||||
|
|
||||||
|
store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
|
||||||
|
agent_id=test_graph.id,
|
||||||
|
agent_version=test_graph.version,
|
||||||
|
slug="test-slug",
|
||||||
|
name="Test name",
|
||||||
|
sub_heading="Test sub heading",
|
||||||
|
video_url=None,
|
||||||
|
image_urls=[],
|
||||||
|
description="Test description",
|
||||||
|
categories=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
store_listing = await server.agent_server.test_create_store_listing(
|
||||||
|
store_submission_request, test_user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(store_listing, fastapi.responses.JSONResponse):
|
||||||
|
assert False, "Failed to create store listing"
|
||||||
|
|
||||||
|
slv_id = (
|
||||||
|
store_listing.store_listing_version_id
|
||||||
|
if store_listing.store_listing_version_id is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert slv_id is not None
|
||||||
|
|
||||||
|
admin = autogpt_libs.auth.models.User(
|
||||||
|
user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
|
||||||
|
role="admin",
|
||||||
|
email="admin@example.com",
|
||||||
|
phone_number="1234567890",
|
||||||
|
)
|
||||||
|
await server.agent_server.test_review_store_listing(
|
||||||
|
backend.server.v2.store.model.ReviewSubmissionRequest(
|
||||||
|
store_listing_version_id=slv_id,
|
||||||
|
isApproved=True,
|
||||||
|
comments="Test comments",
|
||||||
|
),
|
||||||
|
admin,
|
||||||
|
)
|
||||||
|
|
||||||
|
alt_test_user = await create_test_user(alt_user=True)
|
||||||
|
|
||||||
|
data = {"input_1": "Hello", "input_2": "World"}
|
||||||
|
graph_exec_id = await execute_graph(
|
||||||
|
server.agent_server,
|
||||||
|
test_graph,
|
||||||
|
alt_test_user,
|
||||||
|
data,
|
||||||
|
4,
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_sample_graph_executions(
|
||||||
|
server.agent_server, test_graph, alt_test_user, graph_exec_id
|
||||||
|
)
|
||||||
|
logger.info("Completed test_agent_execution")
|
||||||
|
|
|
@ -140,10 +140,10 @@ async def main():
|
||||||
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
|
||||||
for user in users:
|
for user in users:
|
||||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||||
for _ in range(num_agents): # Create 1 UserAgent per user
|
for _ in range(num_agents): # Create 1 LibraryAgent per user
|
||||||
graph = random.choice(agent_graphs)
|
graph = random.choice(agent_graphs)
|
||||||
preset = random.choice(agent_presets)
|
preset = random.choice(agent_presets)
|
||||||
user_agent = await db.useragent.create(
|
user_agent = await db.libraryagent.create(
|
||||||
data={
|
data={
|
||||||
"userId": user.id,
|
"userId": user.id,
|
||||||
"agentId": graph.id,
|
"agentId": graph.id,
|
||||||
|
|
|
@ -123,14 +123,22 @@ export default function PrivatePage() {
|
||||||
|
|
||||||
const allCredentials = providers
|
const allCredentials = providers
|
||||||
? Object.values(providers).flatMap((provider) =>
|
? Object.values(providers).flatMap((provider) =>
|
||||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
|
[
|
||||||
|
...provider.savedOAuthCredentials,
|
||||||
|
...provider.savedApiKeys,
|
||||||
|
...provider.savedUserPasswordCredentials,
|
||||||
|
]
|
||||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||||
.map((credentials) => ({
|
.map((credentials) => ({
|
||||||
...credentials,
|
...credentials,
|
||||||
provider: provider.provider,
|
provider: provider.provider,
|
||||||
providerName: provider.providerName,
|
providerName: provider.providerName,
|
||||||
ProviderIcon: providerIcons[provider.provider],
|
ProviderIcon: providerIcons[provider.provider],
|
||||||
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
|
TypeIcon: {
|
||||||
|
oauth2: IconUser,
|
||||||
|
api_key: IconKey,
|
||||||
|
user_password: IconKey,
|
||||||
|
}[credentials.type],
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
: [];
|
: [];
|
||||||
|
@ -175,6 +183,7 @@ export default function PrivatePage() {
|
||||||
{
|
{
|
||||||
oauth2: "OAuth2 credentials",
|
oauth2: "OAuth2 credentials",
|
||||||
api_key: "API key",
|
api_key: "API key",
|
||||||
|
user_password: "User password",
|
||||||
}[cred.type]
|
}[cred.type]
|
||||||
}{" "}
|
}{" "}
|
||||||
- <code>{cred.id}</code>
|
- <code>{cred.id}</code>
|
||||||
|
|
|
@ -123,14 +123,22 @@ export default function PrivatePage() {
|
||||||
|
|
||||||
const allCredentials = providers
|
const allCredentials = providers
|
||||||
? Object.values(providers).flatMap((provider) =>
|
? Object.values(providers).flatMap((provider) =>
|
||||||
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
|
[
|
||||||
|
...provider.savedOAuthCredentials,
|
||||||
|
...provider.savedApiKeys,
|
||||||
|
...provider.savedUserPasswordCredentials,
|
||||||
|
]
|
||||||
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
.filter((cred) => !hiddenCredentials.includes(cred.id))
|
||||||
.map((credentials) => ({
|
.map((credentials) => ({
|
||||||
...credentials,
|
...credentials,
|
||||||
provider: provider.provider,
|
provider: provider.provider,
|
||||||
providerName: provider.providerName,
|
providerName: provider.providerName,
|
||||||
ProviderIcon: providerIcons[provider.provider],
|
ProviderIcon: providerIcons[provider.provider],
|
||||||
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
|
TypeIcon: {
|
||||||
|
oauth2: IconUser,
|
||||||
|
api_key: IconKey,
|
||||||
|
user_password: IconKey,
|
||||||
|
}[credentials.type],
|
||||||
})),
|
})),
|
||||||
)
|
)
|
||||||
: [];
|
: [];
|
||||||
|
|
|
@ -73,7 +73,9 @@ export const providerIcons: Record<
|
||||||
open_router: fallbackIcon,
|
open_router: fallbackIcon,
|
||||||
pinecone: fallbackIcon,
|
pinecone: fallbackIcon,
|
||||||
slant3d: fallbackIcon,
|
slant3d: fallbackIcon,
|
||||||
|
smtp: fallbackIcon,
|
||||||
replicate: fallbackIcon,
|
replicate: fallbackIcon,
|
||||||
|
reddit: fallbackIcon,
|
||||||
fal: fallbackIcon,
|
fal: fallbackIcon,
|
||||||
revid: fallbackIcon,
|
revid: fallbackIcon,
|
||||||
twitter: FaTwitter,
|
twitter: FaTwitter,
|
||||||
|
@ -105,6 +107,10 @@ export const CredentialsInput: FC<{
|
||||||
const credentials = useCredentials(selfKey);
|
const credentials = useCredentials(selfKey);
|
||||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||||
useState(false);
|
useState(false);
|
||||||
|
const [
|
||||||
|
isUserPasswordCredentialsModalOpen,
|
||||||
|
setUserPasswordCredentialsModalOpen,
|
||||||
|
] = useState(false);
|
||||||
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
||||||
const [oAuthPopupController, setOAuthPopupController] =
|
const [oAuthPopupController, setOAuthPopupController] =
|
||||||
useState<AbortController | null>(null);
|
useState<AbortController | null>(null);
|
||||||
|
@ -120,8 +126,10 @@ export const CredentialsInput: FC<{
|
||||||
providerName,
|
providerName,
|
||||||
supportsApiKey,
|
supportsApiKey,
|
||||||
supportsOAuth2,
|
supportsOAuth2,
|
||||||
|
supportsUserPassword,
|
||||||
savedApiKeys,
|
savedApiKeys,
|
||||||
savedOAuthCredentials,
|
savedOAuthCredentials,
|
||||||
|
savedUserPasswordCredentials,
|
||||||
oAuthCallback,
|
oAuthCallback,
|
||||||
} = credentials;
|
} = credentials;
|
||||||
|
|
||||||
|
@ -235,6 +243,17 @@ export const CredentialsInput: FC<{
|
||||||
providerName={providerName}
|
providerName={providerName}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<UserPasswordCredentialsModal
|
||||||
|
credentialsFieldName={selfKey}
|
||||||
|
open={isUserPasswordCredentialsModalOpen}
|
||||||
|
onClose={() => setUserPasswordCredentialsModalOpen(false)}
|
||||||
|
onCredentialsCreate={(creds) => {
|
||||||
|
onSelectCredentials(creds);
|
||||||
|
setUserPasswordCredentialsModalOpen(false);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -243,13 +262,18 @@ export const CredentialsInput: FC<{
|
||||||
selectedCredentials &&
|
selectedCredentials &&
|
||||||
!savedApiKeys
|
!savedApiKeys
|
||||||
.concat(savedOAuthCredentials)
|
.concat(savedOAuthCredentials)
|
||||||
|
.concat(savedUserPasswordCredentials)
|
||||||
.some((c) => c.id === selectedCredentials.id)
|
.some((c) => c.id === selectedCredentials.id)
|
||||||
) {
|
) {
|
||||||
onSelectCredentials(undefined);
|
onSelectCredentials(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
// No saved credentials yet
|
// No saved credentials yet
|
||||||
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
|
if (
|
||||||
|
savedApiKeys.length === 0 &&
|
||||||
|
savedOAuthCredentials.length === 0 &&
|
||||||
|
savedUserPasswordCredentials.length === 0
|
||||||
|
) {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="mb-2 flex gap-1">
|
<div className="mb-2 flex gap-1">
|
||||||
|
@ -271,6 +295,12 @@ export const CredentialsInput: FC<{
|
||||||
Enter API key
|
Enter API key
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<Button onClick={() => setUserPasswordCredentialsModalOpen(true)}>
|
||||||
|
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||||
|
Enter user password
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
{modals}
|
{modals}
|
||||||
{oAuthError && (
|
{oAuthError && (
|
||||||
|
@ -280,12 +310,29 @@ export const CredentialsInput: FC<{
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const singleCredential =
|
const getCredentialCounts = () => ({
|
||||||
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
|
apiKeys: savedApiKeys.length,
|
||||||
? savedApiKeys[0]
|
oauth: savedOAuthCredentials.length,
|
||||||
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
|
userPass: savedUserPasswordCredentials.length,
|
||||||
? savedOAuthCredentials[0]
|
});
|
||||||
: null;
|
|
||||||
|
const getSingleCredential = () => {
|
||||||
|
const counts = getCredentialCounts();
|
||||||
|
const totalCredentials = Object.values(counts).reduce(
|
||||||
|
(sum, count) => sum + count,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (totalCredentials !== 1) return null;
|
||||||
|
|
||||||
|
if (counts.apiKeys === 1) return savedApiKeys[0];
|
||||||
|
if (counts.oauth === 1) return savedOAuthCredentials[0];
|
||||||
|
if (counts.userPass === 1) return savedUserPasswordCredentials[0];
|
||||||
|
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
const singleCredential = getSingleCredential();
|
||||||
|
|
||||||
if (singleCredential) {
|
if (singleCredential) {
|
||||||
if (!selectedCredentials) {
|
if (!selectedCredentials) {
|
||||||
|
@ -309,6 +356,7 @@ export const CredentialsInput: FC<{
|
||||||
} else {
|
} else {
|
||||||
const selectedCreds = savedApiKeys
|
const selectedCreds = savedApiKeys
|
||||||
.concat(savedOAuthCredentials)
|
.concat(savedOAuthCredentials)
|
||||||
|
.concat(savedUserPasswordCredentials)
|
||||||
.find((c) => c.id == newValue)!;
|
.find((c) => c.id == newValue)!;
|
||||||
|
|
||||||
onSelectCredentials({
|
onSelectCredentials({
|
||||||
|
@ -347,6 +395,13 @@ export const CredentialsInput: FC<{
|
||||||
{credentials.title}
|
{credentials.title}
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
))}
|
))}
|
||||||
|
{savedUserPasswordCredentials.map((credentials, index) => (
|
||||||
|
<SelectItem key={index} value={credentials.id}>
|
||||||
|
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||||
|
<IconUserPlus className="mr-1.5 inline" />
|
||||||
|
{credentials.title}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
<SelectSeparator />
|
<SelectSeparator />
|
||||||
{supportsOAuth2 && (
|
{supportsOAuth2 && (
|
||||||
<SelectItem value="sign-in">
|
<SelectItem value="sign-in">
|
||||||
|
@ -360,6 +415,12 @@ export const CredentialsInput: FC<{
|
||||||
Add new API key
|
Add new API key
|
||||||
</SelectItem>
|
</SelectItem>
|
||||||
)}
|
)}
|
||||||
|
{supportsUserPassword && (
|
||||||
|
<SelectItem value="add-user-password">
|
||||||
|
<IconUserPlus className="mr-1.5 inline" />
|
||||||
|
Add new user password
|
||||||
|
</SelectItem>
|
||||||
|
)}
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
{modals}
|
{modals}
|
||||||
|
@ -506,6 +567,128 @@ export const APIKeyCredentialsModal: FC<{
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const UserPasswordCredentialsModal: FC<{
|
||||||
|
credentialsFieldName: string;
|
||||||
|
open: boolean;
|
||||||
|
onClose: () => void;
|
||||||
|
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||||
|
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
|
||||||
|
const credentials = useCredentials(credentialsFieldName);
|
||||||
|
|
||||||
|
const formSchema = z.object({
|
||||||
|
username: z.string().min(1, "Username is required"),
|
||||||
|
password: z.string().min(1, "Password is required"),
|
||||||
|
title: z.string().min(1, "Name is required"),
|
||||||
|
});
|
||||||
|
|
||||||
|
const form = useForm<z.infer<typeof formSchema>>({
|
||||||
|
resolver: zodResolver(formSchema),
|
||||||
|
defaultValues: {
|
||||||
|
username: "",
|
||||||
|
password: "",
|
||||||
|
title: "",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (
|
||||||
|
!credentials ||
|
||||||
|
credentials.isLoading ||
|
||||||
|
!credentials.supportsUserPassword
|
||||||
|
) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { schema, provider, providerName, createUserPasswordCredentials } =
|
||||||
|
credentials;
|
||||||
|
|
||||||
|
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||||
|
const newCredentials = await createUserPasswordCredentials({
|
||||||
|
username: values.username,
|
||||||
|
password: values.password,
|
||||||
|
title: values.title,
|
||||||
|
});
|
||||||
|
onCredentialsCreate({
|
||||||
|
provider,
|
||||||
|
id: newCredentials.id,
|
||||||
|
type: "user_password",
|
||||||
|
title: newCredentials.title,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog
|
||||||
|
open={open}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
if (!open) onClose();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<DialogContent>
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>Add new user password for {providerName}</DialogTitle>
|
||||||
|
</DialogHeader>
|
||||||
|
<Form {...form}>
|
||||||
|
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="username"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Username</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="text"
|
||||||
|
placeholder="Enter username..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="password"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Password</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="Enter password..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name="title"
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel>Name</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<Input
|
||||||
|
type="text"
|
||||||
|
placeholder="Enter a name for this user password..."
|
||||||
|
{...field}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<Button type="submit" className="w-full">
|
||||||
|
Save & use this user password
|
||||||
|
</Button>
|
||||||
|
</form>
|
||||||
|
</Form>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export const OAuth2FlowWaitingModal: FC<{
|
export const OAuth2FlowWaitingModal: FC<{
|
||||||
open: boolean;
|
open: boolean;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
|
|
|
@ -5,6 +5,7 @@ import {
|
||||||
CredentialsMetaResponse,
|
CredentialsMetaResponse,
|
||||||
CredentialsProviderName,
|
CredentialsProviderName,
|
||||||
PROVIDER_NAMES,
|
PROVIDER_NAMES,
|
||||||
|
UserPasswordCredentials,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
import { createContext, useCallback, useEffect, useState } from "react";
|
import { createContext, useCallback, useEffect, useState } from "react";
|
||||||
|
@ -20,10 +21,13 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||||
discord: "Discord",
|
discord: "Discord",
|
||||||
d_id: "D-ID",
|
d_id: "D-ID",
|
||||||
e2b: "E2B",
|
e2b: "E2B",
|
||||||
|
exa: "Exa",
|
||||||
|
fal: "FAL",
|
||||||
github: "GitHub",
|
github: "GitHub",
|
||||||
google: "Google",
|
google: "Google",
|
||||||
google_maps: "Google Maps",
|
google_maps: "Google Maps",
|
||||||
groq: "Groq",
|
groq: "Groq",
|
||||||
|
hubspot: "Hubspot",
|
||||||
ideogram: "Ideogram",
|
ideogram: "Ideogram",
|
||||||
jina: "Jina",
|
jina: "Jina",
|
||||||
medium: "Medium",
|
medium: "Medium",
|
||||||
|
@ -35,13 +39,12 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||||
open_router: "Open Router",
|
open_router: "Open Router",
|
||||||
pinecone: "Pinecone",
|
pinecone: "Pinecone",
|
||||||
slant3d: "Slant3D",
|
slant3d: "Slant3D",
|
||||||
|
smtp: "SMTP",
|
||||||
|
reddit: "Reddit",
|
||||||
replicate: "Replicate",
|
replicate: "Replicate",
|
||||||
fal: "FAL",
|
|
||||||
revid: "Rev.ID",
|
revid: "Rev.ID",
|
||||||
twitter: "Twitter",
|
twitter: "Twitter",
|
||||||
unreal_speech: "Unreal Speech",
|
unreal_speech: "Unreal Speech",
|
||||||
exa: "Exa",
|
|
||||||
hubspot: "Hubspot",
|
|
||||||
} as const;
|
} as const;
|
||||||
// --8<-- [end:CredentialsProviderNames]
|
// --8<-- [end:CredentialsProviderNames]
|
||||||
|
|
||||||
|
@ -50,11 +53,17 @@ type APIKeyCredentialsCreatable = Omit<
|
||||||
"id" | "provider" | "type"
|
"id" | "provider" | "type"
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
type UserPasswordCredentialsCreatable = Omit<
|
||||||
|
UserPasswordCredentials,
|
||||||
|
"id" | "provider" | "type"
|
||||||
|
>;
|
||||||
|
|
||||||
export type CredentialsProviderData = {
|
export type CredentialsProviderData = {
|
||||||
provider: CredentialsProviderName;
|
provider: CredentialsProviderName;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
savedApiKeys: CredentialsMetaResponse[];
|
savedApiKeys: CredentialsMetaResponse[];
|
||||||
savedOAuthCredentials: CredentialsMetaResponse[];
|
savedOAuthCredentials: CredentialsMetaResponse[];
|
||||||
|
savedUserPasswordCredentials: CredentialsMetaResponse[];
|
||||||
oAuthCallback: (
|
oAuthCallback: (
|
||||||
code: string,
|
code: string,
|
||||||
state_token: string,
|
state_token: string,
|
||||||
|
@ -62,6 +71,9 @@ export type CredentialsProviderData = {
|
||||||
createAPIKeyCredentials: (
|
createAPIKeyCredentials: (
|
||||||
credentials: APIKeyCredentialsCreatable,
|
credentials: APIKeyCredentialsCreatable,
|
||||||
) => Promise<CredentialsMetaResponse>;
|
) => Promise<CredentialsMetaResponse>;
|
||||||
|
createUserPasswordCredentials: (
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
) => Promise<CredentialsMetaResponse>;
|
||||||
deleteCredentials: (
|
deleteCredentials: (
|
||||||
id: string,
|
id: string,
|
||||||
force?: boolean,
|
force?: boolean,
|
||||||
|
@ -106,6 +118,11 @@ export default function CredentialsProvider({
|
||||||
...updatedProvider.savedOAuthCredentials,
|
...updatedProvider.savedOAuthCredentials,
|
||||||
credentials,
|
credentials,
|
||||||
];
|
];
|
||||||
|
} else if (credentials.type === "user_password") {
|
||||||
|
updatedProvider.savedUserPasswordCredentials = [
|
||||||
|
...updatedProvider.savedUserPasswordCredentials,
|
||||||
|
credentials,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -147,6 +164,22 @@ export default function CredentialsProvider({
|
||||||
[api, addCredentials],
|
[api, addCredentials],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
|
||||||
|
const createUserPasswordCredentials = useCallback(
|
||||||
|
async (
|
||||||
|
provider: CredentialsProviderName,
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
): Promise<CredentialsMetaResponse> => {
|
||||||
|
const credsMeta = await api.createUserPasswordCredentials({
|
||||||
|
provider,
|
||||||
|
...credentials,
|
||||||
|
});
|
||||||
|
addCredentials(provider, credsMeta);
|
||||||
|
return credsMeta;
|
||||||
|
},
|
||||||
|
[api, addCredentials],
|
||||||
|
);
|
||||||
|
|
||||||
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
/** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
|
||||||
const deleteCredentials = useCallback(
|
const deleteCredentials = useCallback(
|
||||||
async (
|
async (
|
||||||
|
@ -171,7 +204,10 @@ export default function CredentialsProvider({
|
||||||
updatedProvider.savedOAuthCredentials.filter(
|
updatedProvider.savedOAuthCredentials.filter(
|
||||||
(cred) => cred.id !== id,
|
(cred) => cred.id !== id,
|
||||||
);
|
);
|
||||||
|
updatedProvider.savedUserPasswordCredentials =
|
||||||
|
updatedProvider.savedUserPasswordCredentials.filter(
|
||||||
|
(cred) => cred.id !== id,
|
||||||
|
);
|
||||||
return {
|
return {
|
||||||
...prev,
|
...prev,
|
||||||
[provider]: updatedProvider,
|
[provider]: updatedProvider,
|
||||||
|
@ -190,12 +226,18 @@ export default function CredentialsProvider({
|
||||||
const credentialsByProvider = response.reduce(
|
const credentialsByProvider = response.reduce(
|
||||||
(acc, cred) => {
|
(acc, cred) => {
|
||||||
if (!acc[cred.provider]) {
|
if (!acc[cred.provider]) {
|
||||||
acc[cred.provider] = { oauthCreds: [], apiKeys: [] };
|
acc[cred.provider] = {
|
||||||
|
oauthCreds: [],
|
||||||
|
apiKeys: [],
|
||||||
|
userPasswordCreds: [],
|
||||||
|
};
|
||||||
}
|
}
|
||||||
if (cred.type === "oauth2") {
|
if (cred.type === "oauth2") {
|
||||||
acc[cred.provider].oauthCreds.push(cred);
|
acc[cred.provider].oauthCreds.push(cred);
|
||||||
} else if (cred.type === "api_key") {
|
} else if (cred.type === "api_key") {
|
||||||
acc[cred.provider].apiKeys.push(cred);
|
acc[cred.provider].apiKeys.push(cred);
|
||||||
|
} else if (cred.type === "user_password") {
|
||||||
|
acc[cred.provider].userPasswordCreds.push(cred);
|
||||||
}
|
}
|
||||||
return acc;
|
return acc;
|
||||||
},
|
},
|
||||||
|
@ -204,6 +246,7 @@ export default function CredentialsProvider({
|
||||||
{
|
{
|
||||||
oauthCreds: CredentialsMetaResponse[];
|
oauthCreds: CredentialsMetaResponse[];
|
||||||
apiKeys: CredentialsMetaResponse[];
|
apiKeys: CredentialsMetaResponse[];
|
||||||
|
userPasswordCreds: CredentialsMetaResponse[];
|
||||||
}
|
}
|
||||||
>,
|
>,
|
||||||
);
|
);
|
||||||
|
@ -220,6 +263,8 @@ export default function CredentialsProvider({
|
||||||
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
|
||||||
savedOAuthCredentials:
|
savedOAuthCredentials:
|
||||||
credentialsByProvider[provider]?.oauthCreds ?? [],
|
credentialsByProvider[provider]?.oauthCreds ?? [],
|
||||||
|
savedUserPasswordCredentials:
|
||||||
|
credentialsByProvider[provider]?.userPasswordCreds ?? [],
|
||||||
oAuthCallback: (code: string, state_token: string) =>
|
oAuthCallback: (code: string, state_token: string) =>
|
||||||
oAuthCallback(
|
oAuthCallback(
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
|
@ -233,6 +278,13 @@ export default function CredentialsProvider({
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
credentials,
|
credentials,
|
||||||
),
|
),
|
||||||
|
createUserPasswordCredentials: (
|
||||||
|
credentials: UserPasswordCredentialsCreatable,
|
||||||
|
) =>
|
||||||
|
createUserPasswordCredentials(
|
||||||
|
provider as CredentialsProviderName,
|
||||||
|
credentials,
|
||||||
|
),
|
||||||
deleteCredentials: (id: string, force: boolean = false) =>
|
deleteCredentials: (id: string, force: boolean = false) =>
|
||||||
deleteCredentials(
|
deleteCredentials(
|
||||||
provider as CredentialsProviderName,
|
provider as CredentialsProviderName,
|
||||||
|
@ -245,7 +297,13 @@ export default function CredentialsProvider({
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}, [api, createAPIKeyCredentials, deleteCredentials, oAuthCallback]);
|
}, [
|
||||||
|
api,
|
||||||
|
createAPIKeyCredentials,
|
||||||
|
createUserPasswordCredentials,
|
||||||
|
deleteCredentials,
|
||||||
|
oAuthCallback,
|
||||||
|
]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<CredentialsProvidersContext.Provider value={providers}>
|
<CredentialsProvidersContext.Provider value={providers}>
|
||||||
|
|
|
@ -17,12 +17,14 @@ export type CredentialsData =
|
||||||
schema: BlockIOCredentialsSubSchema;
|
schema: BlockIOCredentialsSubSchema;
|
||||||
supportsApiKey: boolean;
|
supportsApiKey: boolean;
|
||||||
supportsOAuth2: boolean;
|
supportsOAuth2: boolean;
|
||||||
|
supportsUserPassword: boolean;
|
||||||
isLoading: true;
|
isLoading: true;
|
||||||
}
|
}
|
||||||
| (CredentialsProviderData & {
|
| (CredentialsProviderData & {
|
||||||
schema: BlockIOCredentialsSubSchema;
|
schema: BlockIOCredentialsSubSchema;
|
||||||
supportsApiKey: boolean;
|
supportsApiKey: boolean;
|
||||||
supportsOAuth2: boolean;
|
supportsOAuth2: boolean;
|
||||||
|
supportsUserPassword: boolean;
|
||||||
isLoading: false;
|
isLoading: false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -72,6 +74,8 @@ export default function useCredentials(
|
||||||
const supportsApiKey =
|
const supportsApiKey =
|
||||||
credentialsSchema.credentials_types.includes("api_key");
|
credentialsSchema.credentials_types.includes("api_key");
|
||||||
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
||||||
|
const supportsUserPassword =
|
||||||
|
credentialsSchema.credentials_types.includes("user_password");
|
||||||
|
|
||||||
// No provider means maybe it's still loading
|
// No provider means maybe it's still loading
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
|
@ -93,13 +97,17 @@ export default function useCredentials(
|
||||||
)
|
)
|
||||||
: provider.savedOAuthCredentials;
|
: provider.savedOAuthCredentials;
|
||||||
|
|
||||||
|
const savedUserPasswordCredentials = provider.savedUserPasswordCredentials;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...provider,
|
...provider,
|
||||||
provider: providerName,
|
provider: providerName,
|
||||||
schema: credentialsSchema,
|
schema: credentialsSchema,
|
||||||
supportsApiKey,
|
supportsApiKey,
|
||||||
supportsOAuth2,
|
supportsOAuth2,
|
||||||
|
supportsUserPassword,
|
||||||
savedOAuthCredentials,
|
savedOAuthCredentials,
|
||||||
|
savedUserPasswordCredentials,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import {
|
||||||
GraphUpdateable,
|
GraphUpdateable,
|
||||||
NodeExecutionResult,
|
NodeExecutionResult,
|
||||||
MyAgentsResponse,
|
MyAgentsResponse,
|
||||||
OAuth2Credentials,
|
|
||||||
ProfileDetails,
|
ProfileDetails,
|
||||||
User,
|
User,
|
||||||
StoreAgentsResponse,
|
StoreAgentsResponse,
|
||||||
|
@ -29,6 +28,8 @@ import {
|
||||||
StoreReview,
|
StoreReview,
|
||||||
ScheduleCreatable,
|
ScheduleCreatable,
|
||||||
Schedule,
|
Schedule,
|
||||||
|
UserPasswordCredentials,
|
||||||
|
Credentials,
|
||||||
APIKeyPermission,
|
APIKeyPermission,
|
||||||
CreateAPIKeyResponse,
|
CreateAPIKeyResponse,
|
||||||
APIKey,
|
APIKey,
|
||||||
|
@ -191,7 +192,17 @@ export default class BackendAPI {
|
||||||
return this._request(
|
return this._request(
|
||||||
"POST",
|
"POST",
|
||||||
`/integrations/${credentials.provider}/credentials`,
|
`/integrations/${credentials.provider}/credentials`,
|
||||||
credentials,
|
{ ...credentials, type: "api_key" },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
createUserPasswordCredentials(
|
||||||
|
credentials: Omit<UserPasswordCredentials, "id" | "type">,
|
||||||
|
): Promise<UserPasswordCredentials> {
|
||||||
|
return this._request(
|
||||||
|
"POST",
|
||||||
|
`/integrations/${credentials.provider}/credentials`,
|
||||||
|
{ ...credentials, type: "user_password" },
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,10 +214,7 @@ export default class BackendAPI {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
getCredentials(
|
getCredentials(provider: string, id: string): Promise<Credentials> {
|
||||||
provider: string,
|
|
||||||
id: string,
|
|
||||||
): Promise<APIKeyCredentials | OAuth2Credentials> {
|
|
||||||
return this._get(`/integrations/${provider}/credentials/${id}`);
|
return this._get(`/integrations/${provider}/credentials/${id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,7 +97,12 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
|
||||||
default?: boolean;
|
default?: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CredentialsType = "api_key" | "oauth2";
|
export type CredentialsType = "api_key" | "oauth2" | "user_password";
|
||||||
|
|
||||||
|
export type Credentials =
|
||||||
|
| APIKeyCredentials
|
||||||
|
| OAuth2Credentials
|
||||||
|
| UserPasswordCredentials;
|
||||||
|
|
||||||
// --8<-- [start:BlockIOCredentialsSubSchema]
|
// --8<-- [start:BlockIOCredentialsSubSchema]
|
||||||
export const PROVIDER_NAMES = {
|
export const PROVIDER_NAMES = {
|
||||||
|
@ -105,10 +110,13 @@ export const PROVIDER_NAMES = {
|
||||||
D_ID: "d_id",
|
D_ID: "d_id",
|
||||||
DISCORD: "discord",
|
DISCORD: "discord",
|
||||||
E2B: "e2b",
|
E2B: "e2b",
|
||||||
|
EXA: "exa",
|
||||||
|
FAL: "fal",
|
||||||
GITHUB: "github",
|
GITHUB: "github",
|
||||||
GOOGLE: "google",
|
GOOGLE: "google",
|
||||||
GOOGLE_MAPS: "google_maps",
|
GOOGLE_MAPS: "google_maps",
|
||||||
GROQ: "groq",
|
GROQ: "groq",
|
||||||
|
HUBSPOT: "hubspot",
|
||||||
IDEOGRAM: "ideogram",
|
IDEOGRAM: "ideogram",
|
||||||
JINA: "jina",
|
JINA: "jina",
|
||||||
MEDIUM: "medium",
|
MEDIUM: "medium",
|
||||||
|
@ -120,13 +128,12 @@ export const PROVIDER_NAMES = {
|
||||||
OPEN_ROUTER: "open_router",
|
OPEN_ROUTER: "open_router",
|
||||||
PINECONE: "pinecone",
|
PINECONE: "pinecone",
|
||||||
SLANT3D: "slant3d",
|
SLANT3D: "slant3d",
|
||||||
|
SMTP: "smtp",
|
||||||
|
TWITTER: "twitter",
|
||||||
REPLICATE: "replicate",
|
REPLICATE: "replicate",
|
||||||
FAL: "fal",
|
REDDIT: "reddit",
|
||||||
REVID: "revid",
|
REVID: "revid",
|
||||||
UNREAL_SPEECH: "unreal_speech",
|
UNREAL_SPEECH: "unreal_speech",
|
||||||
EXA: "exa",
|
|
||||||
HUBSPOT: "hubspot",
|
|
||||||
TWITTER: "twitter",
|
|
||||||
} as const;
|
} as const;
|
||||||
// --8<-- [end:BlockIOCredentialsSubSchema]
|
// --8<-- [end:BlockIOCredentialsSubSchema]
|
||||||
|
|
||||||
|
@ -322,8 +329,15 @@ export type APIKeyCredentials = BaseCredentials & {
|
||||||
expires_at?: number;
|
expires_at?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UserPasswordCredentials = BaseCredentials & {
|
||||||
|
type: "user_password";
|
||||||
|
title: string;
|
||||||
|
username: string;
|
||||||
|
password: string;
|
||||||
|
};
|
||||||
|
|
||||||
/* Mirror of backend/data/integrations.py:Webhook */
|
/* Mirror of backend/data/integrations.py:Webhook */
|
||||||
type Webhook = {
|
export type Webhook = {
|
||||||
id: string;
|
id: string;
|
||||||
url: string;
|
url: string;
|
||||||
provider: CredentialsProviderName;
|
provider: CredentialsProviderName;
|
||||||
|
|
Loading…
Reference in New Issue