Compare commits

...

33 Commits

Author SHA1 Message Date
Nicholas Tindle 2fd8c8d261
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-10 15:11:54 -06:00
Swifty fd6f28fa57
feature(platform): Implement library add, update, remove, archive functionality (#9218)
### Changes 🏗️

1. **Core Features**:
   - Add agents to the user's library.
   - Update library agents (auto-update, favorite, archive, delete).
   - Paginate library agents and presets.
   - Execute graphs using presets.

2. **Refactoring**:
   - Replaced `UserAgent` with `LibraryAgent`.
   - Separated routes for agents and presets.

3. **Schema Changes**:
- Added `LibraryAgent` table with fields like `isArchived`, `isDeleted`,
etc.
   - Soft delete functionality for `AgentPreset`.

4. **Testing**:
   - Updated tests for `LibraryAgent` operations.
   - Added edge case tests for deletion, archiving, and pagination.

5. **Database Migrations**:
   - Migration to drop `UserAgent` and add `LibraryAgent`.
   - Added fields for soft deletion and auto-update.
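The archive and delete features in this PR rest on soft deletion: rows are flagged rather than removed, and list queries filter the flags out before paginating. A minimal stand-alone sketch of that convention (hypothetical field names and helper, not the actual Prisma models):

```python
import math
from dataclasses import dataclass


@dataclass
class LibraryAgent:
    id: str
    is_deleted: bool = False
    is_archived: bool = False


def list_library_agents(
    agents: list[LibraryAgent], page: int = 1, page_size: int = 10
) -> dict:
    # Soft-deleted rows stay in the table but are filtered out of every query.
    visible = [a for a in agents if not a.is_deleted]
    total_pages = math.ceil(len(visible) / page_size) if visible else 0
    start = (page - 1) * page_size
    return {
        "agents": visible[start : start + page_size],
        "pagination": {
            "total_items": len(visible),
            "total_pages": total_pages,
            "current_page": page,
            "page_size": page_size,
        },
    }
```

Keeping flagged rows in place is what lets past executions retain valid foreign keys to a "deleted" agent.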


Note: this includes the changes from the following PRs to avoid merge
conflicts with them:

#9179 
#9211

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-01-10 13:02:53 +01:00
Swifty 4b17cc9963
feat(backend): Add Support for Managing Agent Presets with Pagination and Soft Delete (#9211)
#### Summary
- **New Models**: Added `LibraryAgentPreset`,
`LibraryAgentPresetResponse`, `Pagination`, and
`CreateLibraryAgentPresetRequest`.
- **Database Changes**:
  - Added `isDeleted` column in `AgentPreset` for soft delete.
  - CRUD operations for `AgentPreset`:
    - `get_presets` with pagination.
    - `get_preset` by ID.
    - `create_or_update_preset` for upsert.
    - `delete_preset` to soft delete.
- **API Routes**:
  - `GET /presets`: Fetch paginated presets.
  - `GET /presets/{preset_id}`: Fetch a single preset.
  - `POST /presets`: Create a preset.
  - `PUT /presets/{preset_id}`: Update a preset.
  - `DELETE /presets/{preset_id}`: Soft delete a preset.
- **Tests**:
  - Coverage for models, CRUD operations, and pagination.
- **Migration**:
  - Added `isDeleted` field to support soft delete.

#### Review Notes
- Validate migration scripts and test coverage.
- Ensure API aligns with project standards.
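The preset CRUD described above can be summarized with an in-memory stand-in: `create_or_update_preset` is an upsert, and `delete_preset` only flips a flag. This is a sketch with hypothetical types, not the project's Prisma-backed implementation:

```python
import uuid
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class AgentPreset:
    name: str
    inputs: dict = field(default_factory=dict)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    is_deleted: bool = False


class PresetStore:
    """In-memory stand-in for the AgentPreset table."""

    def __init__(self):
        self._rows: dict[str, AgentPreset] = {}

    def create_or_update_preset(self, preset: AgentPreset) -> AgentPreset:
        # Upsert: insert if the id is new, otherwise overwrite in place.
        self._rows[preset.id] = preset
        return preset

    def get_preset(self, preset_id: str) -> Optional[AgentPreset]:
        # Soft-deleted presets behave as if they don't exist.
        row = self._rows.get(preset_id)
        return row if row and not row.is_deleted else None

    def delete_preset(self, preset_id: str) -> None:
        # Soft delete: flip the flag instead of removing the row.
        if preset_id in self._rows:
            self._rows[preset_id].is_deleted = True
```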

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2025-01-10 12:57:35 +01:00
Swifty 00bb7c67b3
feature(backend): Add ability to execute store agents without agent ownership (#9179)
### Description

This PR enables the execution of store agents even if they are not owned
by the user. Key changes include handling store-listed agents in the
`get_graph` logic, improving execution flow, and ensuring
version-specific handling. These updates support more flexible agent
execution.

### Changes 🏗️

- **Graph Retrieval:** Updated `get_graph` to check store listings for
agents not owned by the user.
- **Version Handling:** Added `graph_version` to execution methods for
consistent version-specific execution.
- **Execution Flow:** Refactored `scheduler.py`, `rest_api.py`, and
other modules for clearer logic and better maintainability.
- **Testing:** Updated `test_manager.py` and other test cases to
validate execution of store-listed agents; added a test for graph access.
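Version-specific execution means an explicit `graph_version` wins, and only without one does execution fall back to the active version. A stand-alone sketch of that resolution rule (hypothetical helper, not the actual codebase API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GraphVersion:
    version: int
    is_active: bool


def resolve_version(
    versions: list[GraphVersion], requested: Optional[int] = None
) -> int:
    # An explicit graph_version is honored exactly;
    # otherwise fall back to the newest active version.
    if requested is not None:
        if not any(v.version == requested for v in versions):
            raise ValueError(f"Unknown graph version {requested}")
        return requested
    active = [v.version for v in versions if v.is_active]
    if not active:
        raise ValueError("Graph has no active version")
    return max(active)
```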

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2025-01-10 12:39:06 +01:00
Nicholas Tindle 6b31356264
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-09 17:09:47 -06:00
Nicholas Tindle a88c865437
fix: lint 2025-01-09 16:05:56 -06:00
Nicholas Tindle 287aa819bb
feat: Test creds for smtp 2025-01-09 16:05:20 -06:00
Nicholas Tindle db21c6d4bc
fix: enable block 2025-01-09 15:32:28 -06:00
Nicholas Tindle 59dd75d016
feat: smtp block 2025-01-09 15:31:19 -06:00
Nicholas Tindle 38761f6706
fix: better docs 2025-01-08 17:30:59 -06:00
Nicholas Tindle 513e4eae4b
feat: add reddit oauth instructions 2025-01-08 16:57:32 -06:00
Nicholas Tindle fec9d348a0
fix: linting 2025-01-08 16:48:17 -06:00
Nicholas Tindle 75634e6155
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-08 16:46:34 -06:00
Nicholas Tindle 6cf77c264a
fix: do single credentials better 2025-01-08 16:42:09 -06:00
Nicholas Tindle 2b5c94d508
fix: re-enable the reddit blocks based on client id and secret 2025-01-06 18:42:20 -06:00
Nicholas Tindle 1c6b33d9fb
fix: merge changes 2025-01-06 18:30:56 -06:00
Nicholas Tindle d4692f33e2
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-06 18:25:59 -06:00
Nicholas Tindle d7a9563d49
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-06 17:26:04 -06:00
Nicholas Tindle 2ea61f8b65
fix: ingest url had type mismatch 2025-01-06 16:17:45 -06:00
Nicholas Tindle 2a5f3d167d
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2025-01-06 12:10:34 -06:00
Nicholas Tindle c0a5a01311
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2024-12-30 16:48:27 -06:00
Nicholas Tindle 0aee309f72
fix: tests 2024-12-20 15:35:33 -06:00
Nicholas Tindle 4c07f6c633
fix: testing? 2024-12-20 15:32:40 -06:00
Nicholas Tindle c39f27bcd4
fix: credential typing 2024-12-20 15:29:27 -06:00
Nicholas Tindle 35dcc6a2a1
fix: linting 2024-12-20 15:18:05 -06:00
Nicholas Tindle bef5637f29
fix: linting 2024-12-20 15:14:10 -06:00
Nicholas Tindle e933502cbd
fix: build out the credentials provider on the front and backend a bit more 2024-12-20 15:11:58 -06:00
Nicholas Tindle 5720225a75
fix: why aren't these fully a-z???? 2024-12-20 13:03:00 -06:00
Nicholas Tindle cb3808cb78
feat: fill out reddit required details + more user/pass stuff 2024-12-20 13:01:25 -06:00
Nicholas Tindle b6b97f10b8
fix: random c 2024-12-20 12:43:21 -06:00
Nicholas Tindle 0a905c6d66
Merge branch 'dev' into ntindle/open-2032-re-enable-getredditpostblock-sendemailblock 2024-12-20 12:36:20 -06:00
Nicholas Tindle 6b3f5b413f
feat: add more credentials changes as required by reddit block 2024-12-20 12:11:41 -06:00
Nicholas Tindle 8d79a62f61
feat: basic credentials fields added 2024-12-18 11:01:30 -06:00
46 changed files with 2101 additions and 488 deletions

View File

@@ -82,10 +82,13 @@ GROQ_API_KEY=
 OPEN_ROUTER_API_KEY=
 
 # Reddit
+# Go to https://www.reddit.com/prefs/apps and create a new app
+# Choose "script" for the type
+# Fill in the redirect uri as <your_frontend_url>/auth/integrations/oauth_callback, e.g. http://localhost:3000/auth/integrations/oauth_callback
 REDDIT_CLIENT_ID=
 REDDIT_CLIENT_SECRET=
-REDDIT_USERNAME=
-REDDIT_PASSWORD=
+REDDIT_USER_AGENT="AutoGPT:1.0 (by /u/autogpt)"
 
 # Discord
 DISCORD_BOT_TOKEN=

View File

@@ -1,22 +1,53 @@
 import smtplib
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
+from typing import Literal
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, SecretStr
 
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import BlockSecret, SchemaField, SecretField
+from backend.data.model import (
+    CredentialsField,
+    CredentialsMetaInput,
+    SchemaField,
+    UserPasswordCredentials,
+)
+from backend.integrations.providers import ProviderName
+
+TEST_CREDENTIALS = UserPasswordCredentials(
+    id="01234567-89ab-cdef-0123-456789abcdef",
+    provider="smtp",
+    username=SecretStr("mock-smtp-username"),
+    password=SecretStr("mock-smtp-password"),
+    title="Mock SMTP credentials",
+)
+TEST_CREDENTIALS_INPUT = {
+    "provider": TEST_CREDENTIALS.provider,
+    "id": TEST_CREDENTIALS.id,
+    "type": TEST_CREDENTIALS.type,
+    "title": TEST_CREDENTIALS.title,
+}
+
+SMTPCredentials = UserPasswordCredentials
+SMTPCredentialsInput = CredentialsMetaInput[
+    Literal[ProviderName.SMTP],
+    Literal["user_password"],
+]
 
 
-class EmailCredentials(BaseModel):
+def SMTPCredentialsField() -> SMTPCredentialsInput:
+    return CredentialsField(
+        description="The SMTP integration requires a username and password.",
+    )
+
+
+class SmtpConfig(BaseModel):
     smtp_server: str = SchemaField(
         default="smtp.gmail.com", description="SMTP server address"
     )
     smtp_port: int = SchemaField(default=25, description="SMTP port number")
-    smtp_username: BlockSecret = SecretField(key="smtp_username")
-    smtp_password: BlockSecret = SecretField(key="smtp_password")
 
-    model_config = ConfigDict(title="Email Credentials")
+    model_config = ConfigDict(title="SMTP Config")
 
 
 class SendEmailBlock(Block):
@@ -30,10 +61,11 @@ class SendEmailBlock(Block):
         body: str = SchemaField(
             description="Body of the email", placeholder="Enter the email body"
         )
-        creds: EmailCredentials = SchemaField(
-            description="SMTP credentials",
-            default=EmailCredentials(),
+        config: SmtpConfig = SchemaField(
+            description="SMTP Config",
+            default=SmtpConfig(),
         )
+        credentials: SMTPCredentialsInput = SMTPCredentialsField()
 
     class Output(BlockSchema):
         status: str = SchemaField(description="Status of the email sending operation")
@@ -43,7 +75,6 @@ class SendEmailBlock(Block):
 
     def __init__(self):
         super().__init__(
-            disabled=True,
             id="4335878a-394e-4e67-adf2-919877ff49ae",
             description="This block sends an email using the provided SMTP credentials.",
             categories={BlockCategory.OUTPUT},
@@ -53,25 +84,29 @@ class SendEmailBlock(Block):
                 "to_email": "recipient@example.com",
                 "subject": "Test Email",
                 "body": "This is a test email.",
-                "creds": {
+                "config": {
                     "smtp_server": "smtp.gmail.com",
                     "smtp_port": 25,
-                    "smtp_username": "your-email@gmail.com",
-                    "smtp_password": "your-gmail-password",
                 },
+                "credentials": TEST_CREDENTIALS_INPUT,
             },
+            test_credentials=TEST_CREDENTIALS,
             test_output=[("status", "Email sent successfully")],
            test_mock={"send_email": lambda *args, **kwargs: "Email sent successfully"},
         )
 
     @staticmethod
     def send_email(
-        creds: EmailCredentials, to_email: str, subject: str, body: str
+        config: SmtpConfig,
+        to_email: str,
+        subject: str,
+        body: str,
+        credentials: SMTPCredentials,
     ) -> str:
-        smtp_server = creds.smtp_server
-        smtp_port = creds.smtp_port
-        smtp_username = creds.smtp_username.get_secret_value()
-        smtp_password = creds.smtp_password.get_secret_value()
+        smtp_server = config.smtp_server
+        smtp_port = config.smtp_port
+        smtp_username = credentials.username.get_secret_value()
+        smtp_password = credentials.password.get_secret_value()
 
         msg = MIMEMultipart()
         msg["From"] = smtp_username
@@ -86,10 +121,13 @@ class SendEmailBlock(Block):
 
         return "Email sent successfully"
 
-    def run(self, input_data: Input, **kwargs) -> BlockOutput:
+    def run(
+        self, input_data: Input, *, credentials: SMTPCredentials, **kwargs
+    ) -> BlockOutput:
         yield "status", self.send_email(
-            input_data.creds,
-            input_data.to_email,
-            input_data.subject,
-            input_data.body,
+            config=input_data.config,
+            to_email=input_data.to_email,
+            subject=input_data.subject,
+            body=input_data.body,
+            credentials=credentials,
         )
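The `send_email` flow in the diff above boils down to assembling a `MIMEMultipart` message and handing it to `smtplib`. A stand-alone stdlib sketch; the STARTTLS upgrade is an assumption, since the diff does not show the transport details:

```python
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def build_message(from_addr: str, to_addr: str, subject: str, body: str) -> MIMEMultipart:
    # Mirrors what SendEmailBlock.send_email assembles before sending.
    msg = MIMEMultipart()
    msg["From"] = from_addr
    msg["To"] = to_addr
    msg["Subject"] = subject
    msg.attach(MIMEText(body, "plain"))
    return msg


def send(msg: MIMEMultipart, server: str, port: int, username: str, password: str) -> None:
    # STARTTLS upgrade before login, as most SMTP submission servers require
    # (assumption; not shown in the diff).
    with smtplib.SMTP(server, port) as smtp:
        smtp.starttls()
        smtp.login(username, password)
        smtp.send_message(msg)
```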

View File

@@ -1,22 +1,48 @@
 from datetime import datetime, timezone
-from typing import Iterator
+from typing import Iterator, Literal
 
 import praw
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, SecretStr
 
 from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
-from backend.data.model import BlockSecret, SchemaField, SecretField
+from backend.data.model import (
+    CredentialsField,
+    CredentialsMetaInput,
+    SchemaField,
+    UserPasswordCredentials,
+)
+from backend.integrations.providers import ProviderName
 from backend.util.mock import MockObject
+from backend.util.settings import Settings
+
+RedditCredentials = UserPasswordCredentials
+RedditCredentialsInput = CredentialsMetaInput[
+    Literal[ProviderName.REDDIT],
+    Literal["user_password"],
+]
 
 
-class RedditCredentials(BaseModel):
-    client_id: BlockSecret = SecretField(key="reddit_client_id")
-    client_secret: BlockSecret = SecretField(key="reddit_client_secret")
-    username: BlockSecret = SecretField(key="reddit_username")
-    password: BlockSecret = SecretField(key="reddit_password")
-    user_agent: str = "AutoGPT:1.0 (by /u/autogpt)"
-
-    model_config = ConfigDict(title="Reddit Credentials")
+def RedditCredentialsField() -> RedditCredentialsInput:
+    """Creates a Reddit credentials input on a block."""
+    return CredentialsField(
+        description="The Reddit integration requires a username and password.",
+    )
+
+
+TEST_CREDENTIALS = UserPasswordCredentials(
+    id="01234567-89ab-cdef-0123-456789abcdef",
+    provider="reddit",
+    username=SecretStr("mock-reddit-username"),
+    password=SecretStr("mock-reddit-password"),
+    title="Mock Reddit credentials",
+)
+TEST_CREDENTIALS_INPUT = {
+    "provider": TEST_CREDENTIALS.provider,
+    "id": TEST_CREDENTIALS.id,
+    "type": TEST_CREDENTIALS.type,
+    "title": TEST_CREDENTIALS.title,
+}
 
 
 class RedditPost(BaseModel):
@@ -31,13 +57,16 @@ class RedditComment(BaseModel):
     comment: str
 
 
+settings = Settings()
+
+
 def get_praw(creds: RedditCredentials) -> praw.Reddit:
     client = praw.Reddit(
-        client_id=creds.client_id.get_secret_value(),
-        client_secret=creds.client_secret.get_secret_value(),
+        client_id=settings.secrets.reddit_client_id,
+        client_secret=settings.secrets.reddit_client_secret,
         username=creds.username.get_secret_value(),
         password=creds.password.get_secret_value(),
-        user_agent=creds.user_agent,
+        user_agent=settings.config.reddit_user_agent,
     )
     me = client.user.me()
     if not me:
@@ -48,11 +77,11 @@ def get_praw(creds: RedditCredentials) -> praw.Reddit:
 
 class GetRedditPostsBlock(Block):
     class Input(BlockSchema):
-        subreddit: str = SchemaField(description="Subreddit name")
-        creds: RedditCredentials = SchemaField(
-            description="Reddit credentials",
-            default=RedditCredentials(),
+        subreddit: str = SchemaField(
+            description="Subreddit name, excluding the /r/ prefix",
+            default="writingprompts",
         )
+        credentials: RedditCredentialsInput = RedditCredentialsField()
         last_minutes: int | None = SchemaField(
             description="Post time to stop minutes ago while fetching posts",
             default=None,
@@ -70,20 +99,18 @@ class GetRedditPostsBlock(Block):
 
     def __init__(self):
         super().__init__(
-            disabled=True,
             id="c6731acb-4285-4ee1-bc9b-03d0766c370f",
             description="This block fetches Reddit posts from a defined subreddit name.",
             categories={BlockCategory.SOCIAL},
+            disabled=(
+                not settings.secrets.reddit_client_id
+                or not settings.secrets.reddit_client_secret
+            ),
             input_schema=GetRedditPostsBlock.Input,
             output_schema=GetRedditPostsBlock.Output,
+            test_credentials=TEST_CREDENTIALS,
             test_input={
-                "creds": {
-                    "client_id": "client_id",
-                    "client_secret": "client_secret",
-                    "username": "username",
-                    "password": "password",
-                    "user_agent": "user_agent",
-                },
+                "credentials": TEST_CREDENTIALS_INPUT,
                 "subreddit": "subreddit",
                 "last_post": "id3",
                 "post_limit": 2,
@@ -103,7 +130,7 @@ class GetRedditPostsBlock(Block):
                 ),
             ],
             test_mock={
-                "get_posts": lambda _: [
+                "get_posts": lambda input_data, credentials: [
                     MockObject(id="id1", title="title1", selftext="body1"),
                     MockObject(id="id2", title="title2", selftext="body2"),
                     MockObject(id="id3", title="title2", selftext="body2"),
@@ -112,14 +139,18 @@ class GetRedditPostsBlock(Block):
         )
 
     @staticmethod
-    def get_posts(input_data: Input) -> Iterator[praw.reddit.Submission]:
-        client = get_praw(input_data.creds)
+    def get_posts(
+        input_data: Input, *, credentials: RedditCredentials
+    ) -> Iterator[praw.reddit.Submission]:
+        client = get_praw(credentials)
         subreddit = client.subreddit(input_data.subreddit)
         return subreddit.new(limit=input_data.post_limit or 10)
 
-    def run(self, input_data: Input, **kwargs) -> BlockOutput:
+    def run(
+        self, input_data: Input, *, credentials: RedditCredentials, **kwargs
+    ) -> BlockOutput:
         current_time = datetime.now(tz=timezone.utc)
-        for post in self.get_posts(input_data):
+        for post in self.get_posts(input_data=input_data, credentials=credentials):
             if input_data.last_minutes:
                 post_datetime = datetime.fromtimestamp(
                     post.created_utc, tz=timezone.utc
@@ -141,9 +172,7 @@ class GetRedditPostsBlock(Block):
 
 class PostRedditCommentBlock(Block):
     class Input(BlockSchema):
-        creds: RedditCredentials = SchemaField(
-            description="Reddit credentials", default=RedditCredentials()
-        )
+        credentials: RedditCredentialsInput = RedditCredentialsField()
         data: RedditComment = SchemaField(description="Reddit comment")
 
     class Output(BlockSchema):
@@ -156,7 +185,15 @@ class PostRedditCommentBlock(Block):
             categories={BlockCategory.SOCIAL},
             input_schema=PostRedditCommentBlock.Input,
             output_schema=PostRedditCommentBlock.Output,
-            test_input={"data": {"post_id": "id", "comment": "comment"}},
+            disabled=(
+                not settings.secrets.reddit_client_id
+                or not settings.secrets.reddit_client_secret
+            ),
+            test_credentials=TEST_CREDENTIALS,
+            test_input={
+                "credentials": TEST_CREDENTIALS_INPUT,
+                "data": {"post_id": "id", "comment": "comment"},
+            },
             test_output=[("comment_id", "dummy_comment_id")],
             test_mock={"reply_post": lambda creds, comment: "dummy_comment_id"},
         )
@@ -170,5 +207,7 @@ class PostRedditCommentBlock(Block):
             raise ValueError("Failed to post comment.")
         return new_comment.id
 
-    def run(self, input_data: Input, **kwargs) -> BlockOutput:
-        yield "comment_id", self.reply_post(input_data.creds, input_data.data)
+    def run(
+        self, input_data: Input, *, credentials: RedditCredentials, **kwargs
+    ) -> BlockOutput:
+        yield "comment_id", self.reply_post(credentials, input_data.data)

View File

@@ -136,6 +136,7 @@ async def create_graph_execution(
     graph_version: int,
     nodes_input: list[tuple[str, BlockInput]],
     user_id: str,
+    preset_id: str | None = None,
 ) -> tuple[str, list[ExecutionResult]]:
     """
     Create a new AgentGraphExecution record.
@@ -163,6 +164,7 @@ async def create_graph_execution(
                 ]
             },
             "userId": user_id,
+            "agentPresetId": preset_id,
         },
         include=GRAPH_EXECUTION_INCLUDE,
     )

View File

@@ -6,7 +6,13 @@ from datetime import datetime, timezone
 from typing import Any, Literal, Optional, Type
 
 import prisma
-from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
+from prisma.models import (
+    AgentGraph,
+    AgentGraphExecution,
+    AgentNode,
+    AgentNodeLink,
+    StoreListingVersion,
+)
 from prisma.types import AgentGraphWhereInput
 from pydantic.fields import computed_field
@@ -529,7 +535,6 @@ async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None
 async def get_graph(
     graph_id: str,
     version: int | None = None,
-    template: bool = False,
     user_id: str | None = None,
     for_export: bool = False,
 ) -> GraphModel | None:
@@ -543,21 +548,36 @@ async def get_graph(
     where_clause: AgentGraphWhereInput = {
         "id": graph_id,
     }
+
     if version is not None:
         where_clause["version"] = version
-    elif not template:
+    else:
         where_clause["isActive"] = True
 
-    # TODO: Fix hack workaround to get adding store agents to work
-    if user_id is not None and not template:
-        where_clause["userId"] = user_id
-
     graph = await AgentGraph.prisma().find_first(
         where=where_clause,
         include=AGENT_GRAPH_INCLUDE,
         order={"version": "desc"},
     )
-    return GraphModel.from_db(graph, for_export) if graph else None
+
+    # The Graph has to be owned by the user or a store listing.
+    if (
+        graph is None
+        or graph.userId != user_id
+        and not (
+            await StoreListingVersion.prisma().find_first(
+                where=prisma.types.StoreListingVersionWhereInput(
+                    agentId=graph_id,
+                    agentVersion=version or graph.version,
+                    isDeleted=False,
+                    StoreListing={"is": {"isApproved": True}},
+                )
+            )
+        )
+    ):
+        return None
+
+    return GraphModel.from_db(graph, for_export)
 
 
 async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
@@ -611,9 +631,7 @@ async def create_graph(graph: Graph, user_id: str) -> GraphModel:
     async with transaction() as tx:
         await __create_graph(tx, graph, user_id)
 
-    if created_graph := await get_graph(
-        graph.id, graph.version, graph.is_template, user_id=user_id
-    ):
+    if created_graph := await get_graph(graph.id, graph.version, user_id=user_id):
         return created_graph
 
     raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
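The new `get_graph` access rule above reduces to a predicate: the requester owns the graph, or an approved, non-deleted store listing exists for it. A stand-alone sketch with simplified types (not the actual Prisma query):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StoreListing:
    agent_id: str
    is_deleted: bool
    is_approved: bool


def can_access_graph(
    graph_owner_id: str,
    requesting_user_id: Optional[str],
    listings: list[StoreListing],
    graph_id: str,
) -> bool:
    # Ownership grants access unconditionally.
    if graph_owner_id == requesting_user_id:
        return True
    # Otherwise the graph must have a live, approved store listing.
    return any(
        listing.agent_id == graph_id and not listing.is_deleted and listing.is_approved
        for listing in listings
    )
```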

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import base64
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -206,20 +207,35 @@ class OAuth2Credentials(_BaseCredentials):
 class APIKeyCredentials(_BaseCredentials):
     type: Literal["api_key"] = "api_key"
     api_key: SecretStr
-    expires_at: Optional[int]
+    expires_at: Optional[int] = Field(
+        default=None,
+        description="Unix timestamp (seconds) indicating when the API key expires (if at all)",
+    )
     """Unix timestamp (seconds) indicating when the API key expires (if at all)"""
 
     def bearer(self) -> str:
         return f"Bearer {self.api_key.get_secret_value()}"
 
 
+class UserPasswordCredentials(_BaseCredentials):
+    type: Literal["user_password"] = "user_password"
+    username: SecretStr
+    password: SecretStr
+
+    def bearer(self) -> str:
+        # Converting the string to bytes using encode()
+        # Base64 encoding it with base64.b64encode()
+        # Converting the resulting bytes back to a string with decode()
+        return f"Basic {base64.b64encode(f'{self.username.get_secret_value()}:{self.password.get_secret_value()}'.encode()).decode()}"
+
+
 Credentials = Annotated[
-    OAuth2Credentials | APIKeyCredentials,
+    OAuth2Credentials | APIKeyCredentials | UserPasswordCredentials,
     Field(discriminator="type"),
 ]
 
-CredentialsType = Literal["api_key", "oauth2"]
+CredentialsType = Literal["api_key", "oauth2", "user_password"]
 
 
 class OAuthState(BaseModel):
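`UserPasswordCredentials.bearer()` above produces an RFC 7617 Basic auth header. The encoding step in isolation:

```python
import base64


def basic_auth_header(username: str, password: str) -> str:
    # RFC 7617: "Basic " + base64(username ":" password), matching
    # what UserPasswordCredentials.bearer() returns.
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    return f"Basic {token}"
```

Note the colon separator means a username containing `:` cannot be round-tripped, a known limitation of the Basic scheme itself.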

View File

@@ -780,7 +780,8 @@ class ExecutionManager(AppService):
         graph_id: str,
         data: BlockInput,
         user_id: str,
-        graph_version: int | None = None,
+        graph_version: int,
+        preset_id: str | None = None,
     ) -> GraphExecutionEntry:
         graph: GraphModel | None = self.db_client.get_graph(
             graph_id=graph_id, user_id=user_id, version=graph_version
@@ -829,6 +830,7 @@ class ExecutionManager(AppService):
             graph_version=graph.version,
             nodes_input=nodes_input,
             user_id=user_id,
+            preset_id=preset_id,
         )
 
         starting_node_execs = []

View File

@@ -63,7 +63,10 @@ def execute_graph(**kwargs):
     try:
         log(f"Executing recurring job for graph #{args.graph_id}")
         get_execution_client().add_execution(
-            args.graph_id, args.input_data, args.user_id
+            graph_id=args.graph_id,
+            data=args.input_data,
+            user_id=args.user_id,
+            graph_version=args.graph_version,
         )
     except Exception as e:
         logger.exception(f"Error executing graph {args.graph_id}: {e}")

View File

@@ -25,9 +25,11 @@ class ProviderName(str, Enum):
     OPENWEATHERMAP = "openweathermap"
     OPEN_ROUTER = "open_router"
     PINECONE = "pinecone"
+    REDDIT = "reddit"
     REPLICATE = "replicate"
     REVID = "revid"
     SLANT3D = "slant3d"
+    SMTP = "smtp"
     TWITTER = "twitter"
     UNREAL_SPEECH = "unreal_speech"
     # --8<-- [end:ProviderName]

View File

@@ -168,7 +168,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
         id = str(uuid4())
         secret = secrets.token_hex(32)
 
-        provider_name = self.PROVIDER_NAME
+        provider_name: ProviderName = self.PROVIDER_NAME
         ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
 
         if register:
             if not credentials:

View File

@@ -1,7 +1,7 @@
 import logging
 
 from backend.data import integrations
-from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials
+from backend.data.model import Credentials
 
 from ._base import WT, BaseWebhooksManager
 
@@ -25,6 +25,6 @@ class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
     async def _deregister_webhook(
         self,
         webhook: integrations.Webhook,
-        credentials: OAuth2Credentials | APIKeyCredentials,
+        credentials: Credentials,
     ) -> None:
         pass

View File

@@ -2,7 +2,7 @@ import logging
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field
 
 from backend.data.graph import set_node_webhook
 from backend.data.integrations import (
@@ -12,12 +12,7 @@ from backend.data.integrations import (
     publish_webhook_event,
     wait_for_webhook_event,
 )
-from backend.data.model import (
-    APIKeyCredentials,
-    Credentials,
-    CredentialsType,
-    OAuth2Credentials,
-)
+from backend.data.model import Credentials, CredentialsType, OAuth2Credentials
 from backend.executor.manager import ExecutionManager
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.integrations.oauth import HANDLERS_BY_NAME
@@ -199,22 +194,15 @@ def get_credential(
 
 
 @router.post("/{provider}/credentials", status_code=201)
-def create_api_key_credentials(
+def create_credentials(
     user_id: Annotated[str, Depends(get_user_id)],
     provider: Annotated[
         ProviderName, Path(title="The provider to create credentials for")
     ],
-    api_key: Annotated[str, Body(title="The API key to store")],
-    title: Annotated[str, Body(title="Optional title for the credentials")],
-    expires_at: Annotated[
-        int | None, Body(title="Unix timestamp when the key expires")
-    ] = None,
-) -> APIKeyCredentials:
-    new_credentials = APIKeyCredentials(
-        provider=provider,
-        api_key=SecretStr(api_key),
-        title=title,
-        expires_at=expires_at,
-    )
+    credential: Credentials,
+) -> Credentials:
+    new_credentials = credential.__class__(
+        provider=provider, **credential.model_dump(exclude={"provider"})
+    )
 
     try:
@@ -320,7 +308,8 @@ async def webhook_ingress_generic(
             continue
         logger.debug(f"Executing graph #{node.graph_id} node #{node.id}")
         executor.add_execution(
-            node.graph_id,
+            graph_id=node.graph_id,
+            graph_version=node.graph_version,
             data={f"webhook_{webhook_id}_payload": payload},
             user_id=webhook.user_id,
         )

View File

@@ -56,3 +56,18 @@ class SetGraphActiveVersion(pydantic.BaseModel):
 class UpdatePermissionsRequest(pydantic.BaseModel):
     permissions: List[APIKeyPermission]
+
+
+class Pagination(pydantic.BaseModel):
+    total_items: int = pydantic.Field(
+        description="Total number of items.", examples=[42]
+    )
+    total_pages: int = pydantic.Field(
+        description="Total number of pages.", examples=[2]
+    )
+    current_page: int = pydantic.Field(
+        description="Current page number.", examples=[1]
+    )
+    page_size: int = pydantic.Field(
+        description="Number of items per page.", examples=[25]
+    )

View File

@@ -2,6 +2,7 @@ import contextlib
 import logging
 import typing

+import autogpt_libs.auth.models
 import fastapi
 import fastapi.responses
 import starlette.middleware.cors
@@ -16,7 +17,9 @@ import backend.data.db
 import backend.data.graph
 import backend.data.user
 import backend.server.routers.v1
+import backend.server.v2.library.model
 import backend.server.v2.library.routes
+import backend.server.v2.store.model
 import backend.server.v2.store.routes
 import backend.util.service
 import backend.util.settings
@@ -117,9 +120,24 @@ class AgentServer(backend.util.service.AppProcess):
     @staticmethod
     async def test_execute_graph(
-        graph_id: str, node_input: dict[typing.Any, typing.Any], user_id: str
+        graph_id: str,
+        graph_version: int,
+        node_input: dict[typing.Any, typing.Any],
+        user_id: str,
     ):
-        return backend.server.routers.v1.execute_graph(graph_id, node_input, user_id)
+        return backend.server.routers.v1.execute_graph(
+            graph_id, graph_version, node_input, user_id
+        )
+
+    @staticmethod
+    async def test_get_graph(
+        graph_id: str,
+        graph_version: int,
+        user_id: str,
+    ):
+        return await backend.server.routers.v1.get_graph(
+            graph_id, user_id, graph_version
+        )

     @staticmethod
     async def test_create_graph(
@@ -149,5 +167,71 @@ class AgentServer(backend.util.service.AppProcess):
     async def test_delete_graph(graph_id: str, user_id: str):
         return await backend.server.routers.v1.delete_graph(graph_id, user_id)

+    @staticmethod
+    async def test_get_presets(user_id: str, page: int = 1, page_size: int = 10):
+        return await backend.server.v2.library.routes.presets.get_presets(
+            user_id=user_id, page=page, page_size=page_size
+        )
+
+    @staticmethod
+    async def test_get_preset(preset_id: str, user_id: str):
+        return await backend.server.v2.library.routes.presets.get_preset(
+            preset_id=preset_id, user_id=user_id
+        )
+
+    @staticmethod
+    async def test_create_preset(
+        preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
+        user_id: str,
+    ):
+        return await backend.server.v2.library.routes.presets.create_preset(
+            preset=preset, user_id=user_id
+        )
+
+    @staticmethod
+    async def test_update_preset(
+        preset_id: str,
+        preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
+        user_id: str,
+    ):
+        return await backend.server.v2.library.routes.presets.update_preset(
+            preset_id=preset_id, preset=preset, user_id=user_id
+        )
+
+    @staticmethod
+    async def test_delete_preset(preset_id: str, user_id: str):
+        return await backend.server.v2.library.routes.presets.delete_preset(
+            preset_id=preset_id, user_id=user_id
+        )
+
+    @staticmethod
+    async def test_execute_preset(
+        graph_id: str,
+        graph_version: int,
+        preset_id: str,
+        node_input: dict[typing.Any, typing.Any],
+        user_id: str,
+    ):
+        return await backend.server.v2.library.routes.presets.execute_preset(
+            graph_id=graph_id,
+            graph_version=graph_version,
+            preset_id=preset_id,
+            node_input=node_input,
+            user_id=user_id,
+        )
+
+    @staticmethod
+    async def test_create_store_listing(
+        request: backend.server.v2.store.model.StoreSubmissionRequest, user_id: str
+    ):
+        return await backend.server.v2.store.routes.create_submission(request, user_id)
+
+    @staticmethod
+    async def test_review_store_listing(
+        request: backend.server.v2.store.model.ReviewSubmissionRequest,
+        user: autogpt_libs.auth.models.User,
+    ):
+        return await backend.server.v2.store.routes.review_submission(request, user)
+
     def set_test_dependency_overrides(self, overrides: dict):
         app.dependency_overrides.update(overrides)

View File

@@ -13,6 +13,7 @@ from typing_extensions import Optional, TypedDict
 import backend.data.block
 import backend.server.integrations.router
 import backend.server.routers.analytics
+import backend.server.v2.library.db
 from backend.data import execution as execution_db
 from backend.data import graph as graph_db
 from backend.data.api_key import (
@@ -180,11 +181,6 @@ async def get_graph(
     tags=["graphs"],
     dependencies=[Depends(auth_middleware)],
 )
-@v1_router.get(
-    path="/templates/{graph_id}/versions",
-    tags=["templates", "graphs"],
-    dependencies=[Depends(auth_middleware)],
-)
 async def get_graph_all_versions(
     graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
 ) -> Sequence[graph_db.GraphModel]:
@@ -200,12 +196,11 @@ async def get_graph_all_versions(
 async def create_new_graph(
     create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
 ) -> graph_db.GraphModel:
-    return await do_create_graph(create_graph, is_template=False, user_id=user_id)
+    return await do_create_graph(create_graph, user_id=user_id)


 async def do_create_graph(
     create_graph: CreateGraph,
-    is_template: bool,
     # user_id doesn't have to be annotated like on other endpoints,
     # because create_graph isn't used directly as an endpoint
     user_id: str,
@@ -217,7 +212,6 @@ async def do_create_graph(
         graph = await graph_db.get_graph(
             create_graph.template_id,
             create_graph.template_version,
-            template=True,
             user_id=user_id,
         )
         if not graph:
@@ -225,13 +219,18 @@ async def do_create_graph(
                 400, detail=f"Template #{create_graph.template_id} not found"
             )
         graph.version = 1
+        # Create a library agent for the new graph
+        await backend.server.v2.library.db.create_library_agent(
+            graph.id,
+            graph.version,
+            user_id,
+        )
     else:
         raise HTTPException(
             status_code=400, detail="Either graph or template_id must be provided."
         )

-    graph.is_template = is_template
-    graph.is_active = not is_template
     graph.reassign_ids(user_id=user_id, reassign_graph_id=True)

     graph = await graph_db.create_graph(graph, user_id=user_id)
@@ -261,11 +260,6 @@ async def delete_graph(
 @v1_router.put(
     path="/graphs/{graph_id}", tags=["graphs"], dependencies=[Depends(auth_middleware)]
 )
-@v1_router.put(
-    path="/templates/{graph_id}",
-    tags=["templates", "graphs"],
-    dependencies=[Depends(auth_middleware)],
-)
 async def update_graph(
     graph_id: str,
     graph: graph_db.Graph,
@@ -298,6 +292,11 @@ async def update_graph(
     if new_graph_version.is_active:
+        # Keep the library agent up to date with the new active version
+        await backend.server.v2.library.db.update_agent_version_in_library(
+            user_id, graph.id, graph.version
+        )

         def get_credentials(credentials_id: str) -> "Credentials | None":
             return integration_creds_manager.get(user_id, credentials_id)
@@ -353,6 +352,12 @@ async def set_graph_active_version(
         version=new_active_version,
         user_id=user_id,
     )
+
+    # Keep the library agent up to date with the new active version
+    await backend.server.v2.library.db.update_agent_version_in_library(
+        user_id, new_active_graph.id, new_active_graph.version
+    )
+
     if current_active_graph and current_active_graph.version != new_active_version:
         # Handle deactivation of the previously active version
         await on_graph_deactivate(
@@ -368,12 +373,13 @@ async def set_graph_active_version(
     )


 def execute_graph(
     graph_id: str,
+    graph_version: int,
     node_input: dict[Any, Any],
     user_id: Annotated[str, Depends(get_user_id)],
 ) -> dict[str, Any]:  # FIXME: add proper return type
     try:
         graph_exec = execution_manager_client().add_execution(
-            graph_id, node_input, user_id=user_id
+            graph_id, node_input, user_id=user_id, graph_version=graph_version
         )
         return {"id": graph_exec.graph_exec_id}
     except Exception as e:
@@ -428,47 +434,6 @@ async def get_graph_run_node_execution_results(
     return await execution_db.get_execution_results(graph_exec_id)


-########################################################
-##################### Templates ########################
-########################################################
-
-
-@v1_router.get(
-    path="/templates",
-    tags=["graphs", "templates"],
-    dependencies=[Depends(auth_middleware)],
-)
-async def get_templates(
-    user_id: Annotated[str, Depends(get_user_id)]
-) -> Sequence[graph_db.GraphModel]:
-    return await graph_db.get_graphs(filter_by="template", user_id=user_id)
-
-
-@v1_router.get(
-    path="/templates/{graph_id}",
-    tags=["templates", "graphs"],
-    dependencies=[Depends(auth_middleware)],
-)
-async def get_template(
-    graph_id: str, version: int | None = None
-) -> graph_db.GraphModel:
-    graph = await graph_db.get_graph(graph_id, version, template=True)
-    if not graph:
-        raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
-    return graph
-
-
-@v1_router.post(
-    path="/templates",
-    tags=["templates", "graphs"],
-    dependencies=[Depends(auth_middleware)],
-)
-async def create_new_template(
-    create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
-) -> graph_db.GraphModel:
-    return await do_create_graph(create_graph, is_template=True, user_id=user_id)
-
-
 ########################################################
 ##################### Schedules ########################
 ########################################################

View File

@@ -1,5 +1,5 @@
+import json
 import logging
-from typing import List

 import prisma.errors
 import prisma.models
@@ -7,6 +7,7 @@ import prisma.types

 import backend.data.graph
 import backend.data.includes
+import backend.server.model
 import backend.server.v2.library.model
 import backend.server.v2.store.exceptions
@@ -14,90 +15,152 @@ logger = logging.getLogger(__name__)

 async def get_library_agents(
-    user_id: str,
-) -> List[backend.server.v2.library.model.LibraryAgent]:
-    """
-    Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
-    """
-    logger.debug(f"Getting library agents for user {user_id}")
-
-    try:
-        # Get agents created by user with nodes and links
-        user_created = await prisma.models.AgentGraph.prisma().find_many(
-            where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
-            include=backend.data.includes.AGENT_GRAPH_INCLUDE,
-        )
-
-        # Get agents in user's library with nodes and links
-        library_agents = await prisma.models.UserAgent.prisma().find_many(
-            where=prisma.types.UserAgentWhereInput(
-                userId=user_id, isDeleted=False, isArchived=False
-            ),
-            include={
-                "Agent": {
-                    "include": {
-                        "AgentNodes": {
-                            "include": {
-                                "Input": True,
-                                "Output": True,
-                                "Webhook": True,
-                                "AgentBlock": True,
-                            }
-                        }
-                    }
-                }
-            },
-        )
-
-        # Convert to Graph models first
-        graphs = []
-
-        # Add user created agents
-        for agent in user_created:
-            try:
-                graphs.append(backend.data.graph.GraphModel.from_db(agent))
-            except Exception as e:
-                logger.error(f"Error processing user created agent {agent.id}: {e}")
-                continue
-
-        # Add library agents
-        for agent in library_agents:
-            if agent.Agent:
-                try:
-                    graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
-                except Exception as e:
-                    logger.error(f"Error processing library agent {agent.agentId}: {e}")
-                    continue
-
-        # Convert Graph models to LibraryAgent models
-        result = []
-        for graph in graphs:
-            result.append(
-                backend.server.v2.library.model.LibraryAgent(
-                    id=graph.id,
-                    version=graph.version,
-                    is_active=graph.is_active,
-                    name=graph.name,
-                    description=graph.description,
-                    isCreatedByUser=any(a.id == graph.id for a in user_created),
-                    input_schema=graph.input_schema,
-                    output_schema=graph.output_schema,
-                )
-            )
-
-        logger.debug(f"Found {len(result)} library agents")
-        return result
-
-    except prisma.errors.PrismaError as e:
-        logger.error(f"Database error getting library agents: {str(e)}")
-        raise backend.server.v2.store.exceptions.DatabaseError(
-            "Failed to fetch library agents"
-        ) from e
+    user_id: str, search_query: str | None = None
+) -> list[backend.server.v2.library.model.LibraryAgent]:
+    logger.debug(
+        f"Fetching library agents for user_id={user_id} search_query={search_query}"
+    )
+
+    if search_query and len(search_query.strip()) > 100:
+        logger.warning(f"Search query too long: {search_query}")
+        raise backend.server.v2.store.exceptions.DatabaseError(
+            "Search query is too long."
+        )
+
+    where_clause = prisma.types.LibraryAgentWhereInput(
+        userId=user_id,
+        isDeleted=False,
+        isArchived=False,
+    )
+
+    if search_query:
+        where_clause["OR"] = [
+            {
+                "Agent": {
+                    "is": {"name": {"contains": search_query, "mode": "insensitive"}}
+                }
+            },
+            {
+                "Agent": {
+                    "is": {
+                        "description": {"contains": search_query, "mode": "insensitive"}
+                    }
+                }
+            },
+        ]
+
+    try:
+        library_agents = await prisma.models.LibraryAgent.prisma().find_many(
+            where=where_clause,
+            include={
+                "Agent": {
+                    "include": {
+                        "AgentNodes": {"include": {"Input": True, "Output": True}}
+                    }
+                }
+            },
+            order=[{"updatedAt": "desc"}],
+        )
+        logger.debug(f"Retrieved {len(library_agents)} agents for user_id={user_id}.")
+        return [
+            backend.server.v2.library.model.LibraryAgent.from_db(agent)
+            for agent in library_agents
+        ]
+    except prisma.errors.PrismaError as e:
+        logger.error(f"Database error fetching library agents: {e}")
+        raise backend.server.v2.store.exceptions.DatabaseError(
+            "Unable to fetch library agents."
+        )
+
+
+async def create_library_agent(
+    agent_id: str, agent_version: int, user_id: str
+) -> prisma.models.LibraryAgent:
+    """
+    Adds an agent to the user's library (LibraryAgent table)
+    """
+    try:
+        library_agent = await prisma.models.LibraryAgent.prisma().create(
+            data=prisma.types.LibraryAgentCreateInput(
+                userId=user_id,
+                agentId=agent_id,
+                agentVersion=agent_version,
+                isCreatedByUser=False,
+                useGraphIsActiveVersion=True,
+            )
+        )
+        return library_agent
+    except prisma.errors.PrismaError as e:
+        logger.error(f"Database error creating agent to library: {str(e)}")
+        raise backend.server.v2.store.exceptions.DatabaseError(
+            "Failed to create agent to library"
+        ) from e
-async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
+async def update_agent_version_in_library(
+    user_id: str, agent_id: str, agent_version: int
+) -> None:
     """
-    Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
-    if they don't already have it
+    Updates the agent version in the library
+    """
+    try:
+        await prisma.models.LibraryAgent.prisma().update(
+            where={
+                "userId": user_id,
+                "agentId": agent_id,
+                "useGraphIsActiveVersion": True,
+            },
+            data=prisma.types.LibraryAgentUpdateInput(
+                Agent=prisma.types.AgentGraphUpdateOneWithoutRelationsInput(
+                    connect=prisma.types.AgentGraphWhereUniqueInput(
+                        id=agent_id,
+                        version=agent_version,
+                    ),
+                ),
+            ),
+        )
+    except prisma.errors.PrismaError as e:
+        logger.error(f"Database error updating agent version in library: {str(e)}")
+        raise backend.server.v2.store.exceptions.DatabaseError(
+            "Failed to update agent version in library"
+        ) from e
+
+
+async def update_library_agent(
+    library_agent_id: str,
+    user_id: str,
+    auto_update_version: bool = False,
+    is_favorite: bool = False,
+    is_archived: bool = False,
+    is_deleted: bool = False,
+) -> None:
+    """
+    Updates the library agent with the given fields
+    """
+    try:
+        await prisma.models.LibraryAgent.prisma().update(
+            where={"id": library_agent_id, "userId": user_id},
+            data=prisma.types.LibraryAgentUpdateInput(
+                useGraphIsActiveVersion=auto_update_version,
+                isFavorite=is_favorite,
+                isArchived=is_archived,
+                isDeleted=is_deleted,
+            ),
+        )
+    except prisma.errors.PrismaError as e:
+        logger.error(f"Database error updating library agent: {str(e)}")
+        raise backend.server.v2.store.exceptions.DatabaseError(
+            "Failed to update library agent"
+        ) from e
+
+
+async def add_store_agent_to_library(
+    store_listing_version_id: str, user_id: str
+) -> None:
+    """
+    Finds the agent from the store listing version and adds it to the user's library (LibraryAgent table)
+    if they don't already have it
+    """
     logger.debug(
@@ -131,7 +194,7 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
     )

     # Check if user already has this agent
-    existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
+    existing_user_agent = await prisma.models.LibraryAgent.prisma().find_first(
         where={
             "userId": user_id,
             "agentId": agent.id,
@@ -145,9 +208,9 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
         )
         return

-    # Create UserAgent entry
-    await prisma.models.UserAgent.prisma().create(
-        data=prisma.types.UserAgentCreateInput(
+    # Create LibraryAgent entry
+    await prisma.models.LibraryAgent.prisma().create(
+        data=prisma.types.LibraryAgentCreateInput(
             userId=user_id,
             agentId=agent.id,
             agentVersion=agent.version,
@@ -163,3 +226,116 @@ async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> N
         raise backend.server.v2.store.exceptions.DatabaseError(
             "Failed to add agent to library"
         ) from e
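`add_store_agent_to_library` calls `find_first` and only creates when nothing matches, so repeat calls for the same (user, agent, version) are no-ops. A minimal in-memory sketch of that idempotency contract (set-based, no Prisma; names are illustrative):

```python
def add_to_library(
    library: set[tuple[str, str, int]],
    user_id: str,
    agent_id: str,
    agent_version: int,
) -> bool:
    """Return True if a new entry was created, False if it already existed."""
    key = (user_id, agent_id, agent_version)
    if key in library:
        # Mirrors the early `return` after find_first reports a match
        return False
    # Mirrors the LibraryAgent create call
    library.add(key)
    return True

lib: set[tuple[str, str, int]] = set()
print(add_to_library(lib, "u1", "a1", 1))  # True  (created)
print(add_to_library(lib, "u1", "a1", 1))  # False (already in library)
```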
##############################################
########### Presets DB Functions #############
##############################################
async def get_presets(
user_id: str, page: int, page_size: int
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
try:
presets = await prisma.models.AgentPreset.prisma().find_many(
where={"userId": user_id},
skip=page * page_size,
take=page_size,
)
total_items = await prisma.models.AgentPreset.prisma().count(
where={"userId": user_id},
)
total_pages = (total_items + page_size - 1) // page_size
presets = [
backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
for preset in presets
]
return backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=presets,
pagination=backend.server.model.Pagination(
total_items=total_items,
total_pages=total_pages,
current_page=page,
page_size=page_size,
),
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting presets: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch presets"
) from e
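Prisma's `skip`/`take` pair in `get_presets` is plain offset arithmetic. A stand-alone sketch of that window, assuming 1-indexed pages (as the `page=1` defaults used for the preset routes suggest) and clamping to the available items — `page_window` is an illustrative helper, not part of the PR:

```python
def page_window(page: int, page_size: int, total_items: int) -> tuple[int, int]:
    """Return (skip, take) for a 1-indexed page, clamped to total_items."""
    skip = (page - 1) * page_size
    take = min(page_size, max(total_items - skip, 0))
    return skip, take

# Slicing a 45-item list with the computed window
items = list(range(45))
skip, take = page_window(2, 10, len(items))
print(items[skip:skip + take][:3])  # [10, 11, 12]
```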
async def get_preset(
user_id: str, preset_id: str
) -> backend.server.v2.library.model.LibraryAgentPreset | None:
try:
preset = await prisma.models.AgentPreset.prisma().find_unique(
where={"id": preset_id, "userId": user_id}, include={"InputPresets": True}
)
if not preset:
return None
return backend.server.v2.library.model.LibraryAgentPreset.from_db(preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error getting preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to fetch preset"
) from e
async def create_or_update_preset(
user_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
preset_id: str | None = None,
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
new_preset = await prisma.models.AgentPreset.prisma().upsert(
where={
"id": preset_id if preset_id else "",
},
data={
"create": {
"userId": user_id,
"name": preset.name,
"description": preset.description,
"agentId": preset.agent_id,
"agentVersion": preset.agent_version,
"isActive": preset.is_active,
"InputPresets": {
"create": [
{"name": name, "data": json.dumps(data)}
for name, data in preset.inputs.items()
]
},
},
"update": {
"name": preset.name,
"description": preset.description,
"isActive": preset.is_active,
},
},
)
return backend.server.v2.library.model.LibraryAgentPreset.from_db(new_preset)
except prisma.errors.PrismaError as e:
logger.error(f"Database error creating preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to create preset"
) from e
async def delete_preset(user_id: str, preset_id: str) -> None:
try:
await prisma.models.AgentPreset.prisma().update(
where={"id": preset_id, "userId": user_id},
data={"isDeleted": True},
)
except prisma.errors.PrismaError as e:
logger.error(f"Database error deleting preset: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to delete preset"
) from e
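`delete_preset` flips `isDeleted` rather than removing the row, which means every read path must exclude soft-deleted records. A tiny in-memory sketch of that contract (illustrative only, no Prisma):

```python
from dataclasses import dataclass

@dataclass
class PresetRow:
    id: str
    is_deleted: bool = False

def soft_delete(rows: list[PresetRow], preset_id: str) -> None:
    # Mirror of the update(..., data={"isDeleted": True}) call above
    for row in rows:
        if row.id == preset_id:
            row.is_deleted = True

def visible(rows: list[PresetRow]) -> list[PresetRow]:
    # Reads must filter soft-deleted rows, or they reappear in listings
    return [r for r in rows if not r.is_deleted]

rows = [PresetRow("p1"), PresetRow("p2")]
soft_delete(rows, "p1")
print([r.id for r in visible(rows)])  # ['p2'] -- p1 is hidden but still stored
```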

View File

@@ -37,7 +37,7 @@ async def test_get_library_agents(mocker):
     ]

     mock_library_agents = [
-        prisma.models.UserAgent(
+        prisma.models.LibraryAgent(
             id="ua1",
             userId="test-user",
             agentId="agent2",
@@ -48,6 +48,7 @@ async def test_get_library_agents(mocker):
             createdAt=datetime.now(),
             updatedAt=datetime.now(),
             isFavorite=False,
+            useGraphIsActiveVersion=True,
             Agent=prisma.models.AgentGraph(
                 id="agent2",
                 version=1,
@@ -67,8 +68,8 @@ async def test_get_library_agents(mocker):
         return_value=mock_user_created
     )

-    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
-    mock_user_agent.return_value.find_many = mocker.AsyncMock(
+    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
+    mock_library_agent.return_value.find_many = mocker.AsyncMock(
         return_value=mock_library_agents
     )
@@ -76,40 +77,16 @@ async def test_get_library_agents(mocker):
     result = await db.get_library_agents("test-user")

     # Verify results
-    assert len(result) == 2
-    assert result[0].id == "agent1"
-    assert result[0].name == "Test Agent 1"
-    assert result[0].description == "Test Description 1"
-    assert result[0].isCreatedByUser is True
-    assert result[1].id == "agent2"
-    assert result[1].name == "Test Agent 2"
-    assert result[1].description == "Test Description 2"
-    assert result[1].isCreatedByUser is False
-
-    # Verify mocks called correctly
-    mock_agent_graph.return_value.find_many.assert_called_once_with(
-        where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
-        include=backend.data.includes.AGENT_GRAPH_INCLUDE,
-    )
-    mock_user_agent.return_value.find_many.assert_called_once_with(
-        where=prisma.types.UserAgentWhereInput(
-            userId="test-user", isDeleted=False, isArchived=False
-        ),
-        include={
-            "Agent": {
-                "include": {
-                    "AgentNodes": {
-                        "include": {
-                            "Input": True,
-                            "Output": True,
-                            "Webhook": True,
-                            "AgentBlock": True,
-                        }
-                    }
-                }
-            }
-        },
-    )
+    assert len(result) == 1
+    assert result[0].id == "ua1"
+    assert result[0].name == "Test Agent 2"
+    assert result[0].description == "Test Description 2"
+    assert result[0].is_created_by_user is False
+    assert result[0].is_latest_version is True
+    assert result[0].is_favorite is False
+    assert result[0].agent_id == "agent2"
+    assert result[0].agent_version == 1
+    assert result[0].preset_id is None
 @pytest.mark.asyncio
@@ -152,26 +129,26 @@ async def test_add_agent_to_library(mocker):
         return_value=mock_store_listing
     )

-    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
-    mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
-    mock_user_agent.return_value.create = mocker.AsyncMock()
+    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
+    mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
+    mock_library_agent.return_value.create = mocker.AsyncMock()

     # Call function
-    await db.add_agent_to_library("version123", "test-user")
+    await db.add_store_agent_to_library("version123", "test-user")

     # Verify mocks called correctly
     mock_store_listing_version.return_value.find_unique.assert_called_once_with(
         where={"id": "version123"}, include={"Agent": True}
     )
-    mock_user_agent.return_value.find_first.assert_called_once_with(
+    mock_library_agent.return_value.find_first.assert_called_once_with(
         where={
             "userId": "test-user",
             "agentId": "agent1",
             "agentVersion": 1,
         }
     )
-    mock_user_agent.return_value.create.assert_called_once_with(
-        data=prisma.types.UserAgentCreateInput(
+    mock_library_agent.return_value.create.assert_called_once_with(
+        data=prisma.types.LibraryAgentCreateInput(
             userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
         )
     )
@@ -189,7 +166,7 @@ async def test_add_agent_to_library_not_found(mocker):
     # Call function and verify exception
     with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
-        await db.add_agent_to_library("version123", "test-user")
+        await db.add_store_agent_to_library("version123", "test-user")

     # Verify mock called correctly
     mock_store_listing_version.return_value.find_unique.assert_called_once_with(

View File

@@ -1,16 +1,112 @@
+import datetime
+import json
 import typing

+import prisma.models
 import pydantic

+import backend.data.block
+import backend.data.graph
+import backend.server.model
+

 class LibraryAgent(pydantic.BaseModel):
     id: str  # Changed from agent_id to match GraphMeta
-    version: int  # Changed from agent_version to match GraphMeta
-    is_active: bool  # Added to match GraphMeta
+
+    agent_id: str
+    agent_version: int  # Changed from agent_version to match GraphMeta
+    preset_id: str | None
+
+    updated_at: datetime.datetime
+
     name: str
     description: str
-    isCreatedByUser: bool

     # Made input_schema and output_schema match GraphMeta's type
     input_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
     output_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
+
+    is_favorite: bool
+    is_created_by_user: bool
+    is_latest_version: bool
@staticmethod
def from_db(agent: prisma.models.LibraryAgent):
if not agent.Agent:
raise ValueError("AgentGraph is required")
graph = backend.data.graph.GraphModel.from_db(agent.Agent)
agent_updated_at = agent.Agent.updatedAt
lib_agent_updated_at = agent.updatedAt
# Use the most recent of the graph's and the library agent's updated_at timestamps
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
else lib_agent_updated_at
)
return LibraryAgent(
id=agent.id,
agent_id=agent.agentId,
agent_version=agent.agentVersion,
updated_at=updated_at,
name=graph.name,
description=graph.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
is_favorite=agent.isFavorite,
is_created_by_user=agent.isCreatedByUser,
is_latest_version=graph.is_active,
preset_id=agent.AgentPreset.id if agent.AgentPreset else None,
)
class LibraryAgentPreset(pydantic.BaseModel):
id: str
updated_at: datetime.datetime
agent_id: str
agent_version: int
name: str
description: str
is_active: bool
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
@staticmethod
def from_db(preset: prisma.models.AgentPreset):
input_data = {}
for data in preset.InputPresets or []:
input_data[data.name] = json.loads(data.data)
return LibraryAgentPreset(
id=preset.id,
updated_at=preset.updatedAt,
agent_id=preset.agentId,
agent_version=preset.agentVersion,
name=preset.name,
description=preset.description,
is_active=preset.isActive,
inputs=input_data,
)
class LibraryAgentPresetResponse(pydantic.BaseModel):
presets: list[LibraryAgentPreset]
pagination: backend.server.model.Pagination
class CreateLibraryAgentPresetRequest(pydantic.BaseModel):
name: str
description: str
inputs: dict[str, typing.Union[backend.data.block.BlockInput, typing.Any]]
agent_id: str
agent_version: int
is_active: bool
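Preset inputs cross the database boundary as JSON strings: `create_or_update_preset` writes each entry with `json.dumps`, and `LibraryAgentPreset.from_db` rebuilds the dict with `json.loads`. A minimal round-trip sketch of that encoding contract (helper names are illustrative, not part of the PR):

```python
import json

def encode_inputs(inputs: dict) -> list[dict]:
    # Mirrors the InputPresets "create" payload: one row per input name
    return [{"name": name, "data": json.dumps(data)} for name, data in inputs.items()]

def decode_inputs(rows: list[dict]) -> dict:
    # Mirrors from_db rebuilding the inputs dict from stored rows
    return {row["name"]: json.loads(row["data"]) for row in rows}

inputs = {"input1": {"type": "string", "value": "test value"}}
assert decode_inputs(encode_inputs(inputs)) == inputs
print("round trip ok")
```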

View File

@@ -1,23 +1,35 @@
+import datetime
+
+import prisma.models
+
+import backend.data.block
+import backend.server.model
 import backend.server.v2.library.model


 def test_library_agent():
     agent = backend.server.v2.library.model.LibraryAgent(
         id="test-agent-123",
-        version=1,
-        is_active=True,
+        agent_id="agent-123",
+        agent_version=1,
+        preset_id=None,
+        updated_at=datetime.datetime.now(),
         name="Test Agent",
         description="Test description",
-        isCreatedByUser=False,
         input_schema={"type": "object", "properties": {}},
         output_schema={"type": "object", "properties": {}},
+        is_favorite=False,
+        is_created_by_user=False,
+        is_latest_version=True,
     )
     assert agent.id == "test-agent-123"
-    assert agent.version == 1
-    assert agent.is_active is True
+    assert agent.agent_id == "agent-123"
+    assert agent.agent_version == 1
     assert agent.name == "Test Agent"
     assert agent.description == "Test description"
-    assert agent.isCreatedByUser is False
+    assert agent.is_favorite is False
+    assert agent.is_created_by_user is False
+    assert agent.is_latest_version is True
     assert agent.input_schema == {"type": "object", "properties": {}}
     assert agent.output_schema == {"type": "object", "properties": {}}
@@ -25,19 +37,148 @@ def test_library_agent():

 def test_library_agent_with_user_created():
     agent = backend.server.v2.library.model.LibraryAgent(
         id="user-agent-456",
-        version=2,
-        is_active=True,
+        agent_id="agent-456",
+        agent_version=2,
+        preset_id=None,
+        updated_at=datetime.datetime.now(),
         name="User Created Agent",
         description="An agent created by the user",
-        isCreatedByUser=True,
         input_schema={"type": "object", "properties": {}},
         output_schema={"type": "object", "properties": {}},
+        is_favorite=False,
+        is_created_by_user=True,
+        is_latest_version=True,
     )
     assert agent.id == "user-agent-456"
-    assert agent.version == 2
-    assert agent.is_active is True
+    assert agent.agent_id == "agent-456"
+    assert agent.agent_version == 2
     assert agent.name == "User Created Agent"
     assert agent.description == "An agent created by the user"
-    assert agent.isCreatedByUser is True
+    assert agent.is_favorite is False
+    assert agent.is_created_by_user is True
+    assert agent.is_latest_version is True
     assert agent.input_schema == {"type": "object", "properties": {}}
     assert agent.output_schema == {"type": "object", "properties": {}}
def test_library_agent_preset():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
updated_at=datetime.datetime.now(),
)
assert preset.id == "preset-123"
assert preset.name == "Test Preset"
assert preset.description == "Test preset description"
assert preset.agent_id == "test-agent-123"
assert preset.agent_version == 1
assert preset.is_active is True
assert preset.inputs == {
"input1": backend.data.block.BlockInput(
name="input1", data={"type": "string", "value": "test value"}
)
}
def test_library_agent_preset_response():
preset = backend.server.v2.library.model.LibraryAgentPreset(
id="preset-123",
name="Test Preset",
description="Test preset description",
agent_id="test-agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
updated_at=datetime.datetime.now(),
)
pagination = backend.server.model.Pagination(
total_items=1, total_pages=1, current_page=1, page_size=10
)
response = backend.server.v2.library.model.LibraryAgentPresetResponse(
presets=[preset], pagination=pagination
)
assert len(response.presets) == 1
assert response.presets[0].id == "preset-123"
assert response.pagination.total_items == 1
assert response.pagination.total_pages == 1
assert response.pagination.current_page == 1
assert response.pagination.page_size == 10
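The `Pagination` fields asserted above follow the usual ceiling-division relationship between item count and page size. A small illustrative helper (not part of the `backend.server.model` API) makes that math explicit:

```python
import math


def paginate(total_items: int, page_size: int, current_page: int) -> dict:
    # Illustrative only: mirrors the Pagination fields asserted above.
    total_pages = math.ceil(total_items / page_size) if page_size else 0
    return {
        "total_items": total_items,
        "total_pages": total_pages,
        "current_page": current_page,
        "page_size": page_size,
    }


print(paginate(1, 10, 1)["total_pages"])  # 1
```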
def test_create_library_agent_preset_request():
request = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
name="New Preset",
description="New preset description",
agent_id="agent-123",
agent_version=1,
is_active=True,
inputs={
"input1": backend.data.block.BlockInput(
name="input1",
data={"type": "string", "value": "test value"},
)
},
)
assert request.name == "New Preset"
assert request.description == "New preset description"
assert request.agent_id == "agent-123"
assert request.agent_version == 1
assert request.is_active is True
assert request.inputs == {
"input1": backend.data.block.BlockInput(
name="input1", data={"type": "string", "value": "test value"}
)
}
def test_library_agent_from_db():
# Create mock DB agent
db_agent = prisma.models.AgentPreset(
id="test-agent-123",
createdAt=datetime.datetime.now(),
updatedAt=datetime.datetime.now(),
agentId="agent-123",
agentVersion=1,
name="Test Agent",
description="Test agent description",
isActive=True,
userId="test-user-123",
isDeleted=False,
InputPresets=[
prisma.models.AgentNodeExecutionInputOutput(
id="input-123",
time=datetime.datetime.now(),
name="input1",
data='{"type": "string", "value": "test value"}',
)
],
)
# Convert to LibraryAgentPreset
agent = backend.server.v2.library.model.LibraryAgentPreset.from_db(db_agent)
assert agent.id == "test-agent-123"
assert agent.agent_version == 1
assert agent.is_active is True
assert agent.name == "Test Agent"
assert agent.description == "Test agent description"
assert agent.inputs == {"input1": {"type": "string", "value": "test value"}}
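The `from_db` conversion exercised above turns each `InputPresets` row, whose `data` field is stored as a JSON string, into an entry of the `inputs` dict. A minimal sketch of that mapping, with plain dicts standing in for the prisma rows:

```python
import json

# Plain dicts stand in for prisma.models.AgentNodeExecutionInputOutput rows;
# each preset input row stores its value as a JSON string in `data`.
input_presets = [
    {"name": "input1", "data": '{"type": "string", "value": "test value"}'},
]

# Map each row's name to its decoded JSON payload, as the test above expects.
inputs = {row["name"]: json.loads(row["data"]) for row in input_presets}
print(inputs)  # {'input1': {'type': 'string', 'value': 'test value'}}
```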


@@ -1,123 +0,0 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import prisma
import backend.data.graph
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
integration_creds_manager = (
backend.integrations.creds_manager.IntegrationCredentialsManager()
)
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception:
logger.exception("Exception occurred whilst getting library agents")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Get the graph from the store listing
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id}, include={"Agent": True}
)
)
if not store_listing_version or not store_listing_version.Agent:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
agent = store_listing_version.Agent
if agent.userId == user_id:
raise fastapi.HTTPException(
status_code=400, detail="Cannot add own agent to library"
)
# Create a new graph from the template
graph = await backend.data.graph.get_graph(
agent.id, agent.version, template=True, user_id=user_id
)
if not graph:
raise fastapi.HTTPException(
status_code=404, detail=f"Agent {agent.id} not found"
)
# Create a deep copy with new IDs
graph.version = 1
graph.is_template = False
graph.is_active = True
graph.reassign_ids(user_id=user_id, reassign_graph_id=True)
# Save the new graph
graph = await backend.data.graph.create_graph(graph, user_id=user_id)
graph = (
await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
graph,
get_credentials=lambda id: integration_creds_manager.get(user_id, id),
)
)
return fastapi.Response(status_code=201)
except Exception:
logger.exception("Exception occurred whilst adding agent to library")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)


@@ -0,0 +1,9 @@
import fastapi
from .agents import router as agents_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(agents_router)


@@ -0,0 +1,148 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import autogpt_libs.utils.cache
import fastapi
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.store.exceptions
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
@router.get(
"/agents",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
"""
Get all agents in the user's library, including both created and saved agents.
"""
try:
agents = await backend.server.v2.library.db.get_library_agents(user_id)
return agents
except Exception as e:
logger.exception(f"Exception occurred whilst getting library agents: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to get library agents"
)
@router.post(
"/agents/{store_listing_version_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=201,
)
async def add_agent_to_library(
store_listing_version_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> fastapi.Response:
"""
Add an agent from the store to the user's library.
Args:
store_listing_version_id (str): ID of the store listing version to add
user_id (str): ID of the authenticated user
Returns:
fastapi.Response: 201 status code on success
Raises:
HTTPException: If there is an error adding the agent to the library
"""
try:
# Use the database function to add the agent to the library
await backend.server.v2.library.db.add_store_agent_to_library(
store_listing_version_id, user_id
)
return fastapi.Response(status_code=201)
except backend.server.v2.store.exceptions.AgentNotFoundError:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
except backend.server.v2.store.exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst adding agent to library: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst adding agent to library: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to add agent to library"
)
@router.put(
"/agents/{library_agent_id}",
tags=["library", "private"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
status_code=204,
)
async def update_library_agent(
library_agent_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
auto_update_version: bool = False,
is_favorite: bool = False,
is_archived: bool = False,
is_deleted: bool = False,
) -> fastapi.Response:
"""
Update the library agent with the given fields.
Args:
library_agent_id (str): ID of the library agent to update
user_id (str): ID of the authenticated user
auto_update_version (bool): Whether to auto-update the agent version
is_favorite (bool): Whether the agent is marked as favorite
is_archived (bool): Whether the agent is archived
is_deleted (bool): Whether the agent is deleted
Returns:
fastapi.Response: 204 status code on success
Raises:
HTTPException: If there is an error updating the library agent
"""
try:
# Use the database function to update the library agent
await backend.server.v2.library.db.update_library_agent(
library_agent_id,
user_id,
auto_update_version,
is_favorite,
is_archived,
is_deleted,
)
return fastapi.Response(status_code=204)
except backend.server.v2.store.exceptions.DatabaseError as e:
logger.exception(f"Database error occurred whilst updating library agent: {e}")
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)
except Exception as e:
logger.exception(
f"Unexpected exception occurred whilst updating library agent: {e}"
)
raise fastapi.HTTPException(
status_code=500, detail="Failed to update library agent"
)


@@ -0,0 +1,156 @@
import logging
import typing
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import autogpt_libs.utils.cache
import fastapi
import backend.data.graph
import backend.executor
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.util.service
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
integration_creds_manager = (
backend.integrations.creds_manager.IntegrationCredentialsManager()
)
@autogpt_libs.utils.cache.thread_cached
def execution_manager_client() -> backend.executor.ExecutionManager:
return backend.util.service.get_service_client(backend.executor.ExecutionManager)
@router.get("/presets")
async def get_presets(
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
page: int = 1,
page_size: int = 10,
) -> backend.server.v2.library.model.LibraryAgentPresetResponse:
try:
presets = await backend.server.v2.library.db.get_presets(
user_id, page, page_size
)
return presets
except Exception as e:
logger.exception(f"Exception occurred whilst getting presets: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get presets")
@router.get("/presets/{preset_id}")
async def get_preset(
preset_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(
status_code=404,
detail=f"Preset {preset_id} not found",
)
return preset
except Exception as e:
logger.exception(f"Exception occurred whilst getting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to get preset")
@router.post("/presets")
async def create_preset(
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
return await backend.server.v2.library.db.create_or_update_preset(
user_id, preset
)
except Exception as e:
logger.exception(f"Exception occurred whilst creating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to create preset")
@router.put("/presets/{preset_id}")
async def update_preset(
preset_id: str,
preset: backend.server.v2.library.model.CreateLibraryAgentPresetRequest,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> backend.server.v2.library.model.LibraryAgentPreset:
try:
return await backend.server.v2.library.db.create_or_update_preset(
user_id, preset, preset_id
)
except Exception as e:
logger.exception(f"Exception occurred whilst updating preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to update preset")
@router.delete("/presets/{preset_id}")
async def delete_preset(
preset_id: str,
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
):
try:
await backend.server.v2.library.db.delete_preset(user_id, preset_id)
return fastapi.Response(status_code=204)
except Exception as e:
logger.exception(f"Exception occurred whilst deleting preset: {e}")
raise fastapi.HTTPException(status_code=500, detail="Failed to delete preset")
@router.post(
path="/presets/{preset_id}/execute",
tags=["presets"],
dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def execute_preset(
graph_id: str,
graph_version: int,
preset_id: str,
node_input: dict[typing.Any, typing.Any],
user_id: typing.Annotated[
str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
],
) -> dict[str, typing.Any]: # FIXME: add proper return type
try:
preset = await backend.server.v2.library.db.get_preset(user_id, preset_id)
if not preset:
raise fastapi.HTTPException(status_code=404, detail="Preset not found")
logger.info(f"Preset inputs: {preset.inputs}")
updated_node_input = node_input.copy()
# Merge in preset input values
for key, value in preset.inputs.items():
if key not in updated_node_input:
updated_node_input[key] = value
execution = execution_manager_client().add_execution(
graph_id=graph_id,
graph_version=graph_version,
data=updated_node_input,
user_id=user_id,
preset_id=preset_id,
)
logger.info(f"Execution added: {execution} with input: {updated_node_input}")
return {"id": execution.graph_exec_id}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise fastapi.HTTPException(status_code=400, detail=msg)
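The merge loop in `execute_preset` gives caller-supplied inputs precedence and uses preset values only to fill missing keys. A standalone sketch of that behaviour (the helper name is illustrative, not part of the route module):

```python
def merge_preset_inputs(node_input: dict, preset_inputs: dict) -> dict:
    # Caller-supplied values win; preset values only fill missing keys,
    # matching the loop in execute_preset above.
    merged = node_input.copy()
    for key, value in preset_inputs.items():
        if key not in merged:
            merged[key] = value
    return merged


print(merge_preset_inputs({"a": 1}, {"a": 99, "b": 2}))  # {'a': 1, 'b': 2}
```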


@@ -1,3 +1,5 @@
+import datetime
+
 import autogpt_libs.auth.depends
 import autogpt_libs.auth.middleware
 import fastapi
@@ -35,21 +37,29 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
     mocked_value = [
         backend.server.v2.library.model.LibraryAgent(
             id="test-agent-1",
-            version=1,
-            is_active=True,
+            agent_id="test-agent-1",
+            agent_version=1,
+            preset_id="preset-1",
+            updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
+            is_favorite=False,
+            is_created_by_user=True,
+            is_latest_version=True,
             name="Test Agent 1",
             description="Test Description 1",
-            isCreatedByUser=True,
             input_schema={"type": "object", "properties": {}},
             output_schema={"type": "object", "properties": {}},
         ),
         backend.server.v2.library.model.LibraryAgent(
             id="test-agent-2",
-            version=1,
-            is_active=True,
+            agent_id="test-agent-2",
+            agent_version=1,
+            preset_id="preset-2",
+            updated_at=datetime.datetime(2023, 1, 1, 0, 0, 0),
+            is_favorite=False,
+            is_created_by_user=False,
+            is_latest_version=True,
             name="Test Agent 2",
             description="Test Description 2",
-            isCreatedByUser=False,
             input_schema={"type": "object", "properties": {}},
             output_schema={"type": "object", "properties": {}},
         ),
@@ -65,10 +75,10 @@ def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
         for agent in response.json()
     ]
     assert len(data) == 2
-    assert data[0].id == "test-agent-1"
-    assert data[0].isCreatedByUser is True
-    assert data[1].id == "test-agent-2"
-    assert data[1].isCreatedByUser is False
+    assert data[0].agent_id == "test-agent-1"
+    assert data[0].is_created_by_user is True
+    assert data[1].agent_id == "test-agent-2"
+    assert data[1].is_created_by_user is False
     mock_db_call.assert_called_once_with("test-user-id")


@@ -325,7 +325,10 @@ async def get_store_submissions(
         where = prisma.types.StoreSubmissionWhereInput(user_id=user_id)
 
         # Query submissions from database
         submissions = await prisma.models.StoreSubmission.prisma().find_many(
-            where=where, skip=skip, take=page_size, order=[{"date_submitted": "desc"}]
+            where=where,
+            skip=skip,
+            take=page_size,
+            order=[{"date_submitted": "desc"}],
         )
 
         # Get total count for pagination
@@ -405,9 +408,7 @@ async def delete_store_submission(
         )
 
         # Delete the submission
-        await prisma.models.StoreListing.prisma().delete(
-            where=prisma.types.StoreListingWhereUniqueInput(id=submission.id)
-        )
+        await prisma.models.StoreListing.prisma().delete(where={"id": submission.id})
 
         logger.debug(
             f"Successfully deleted submission {submission_id} for user {user_id}"
@@ -504,7 +505,15 @@ async def create_store_submission(
                         "subHeading": sub_heading,
                     }
                 },
-            }
+            },
+            include={"StoreListingVersions": True},
+        )
+
+        slv_id = (
+            listing.StoreListingVersions[0].id
+            if listing.StoreListingVersions is not None
+            and len(listing.StoreListingVersions) > 0
+            else None
         )
 
         logger.debug(f"Created store listing for agent {agent_id}")
@@ -521,6 +530,7 @@ async def create_store_submission(
             status=prisma.enums.SubmissionStatus.PENDING,
             runs=0,
             rating=0.0,
+            store_listing_version_id=slv_id,
         )
 
     except (
@@ -811,9 +821,7 @@ async def get_agent(
         agent = store_listing_version.Agent
 
-        graph = await backend.data.graph.get_graph(
-            agent.id, agent.version, template=True
-        )
+        graph = await backend.data.graph.get_graph(agent.id, agent.version)
 
         if not graph:
             raise fastapi.HTTPException(
@@ -832,3 +840,74 @@ async def get_agent(
         raise backend.server.v2.store.exceptions.DatabaseError(
             "Failed to fetch agent"
         ) from e
async def review_store_submission(
store_listing_version_id: str, is_approved: bool, comments: str, reviewer_id: str
) -> prisma.models.StoreListingSubmission:
"""Review a store listing submission."""
try:
store_listing_version = (
await prisma.models.StoreListingVersion.prisma().find_unique(
where={"id": store_listing_version_id},
include={"StoreListing": True},
)
)
if not store_listing_version or not store_listing_version.StoreListing:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing version {store_listing_version_id} not found",
)
status = (
prisma.enums.SubmissionStatus.APPROVED
if is_approved
else prisma.enums.SubmissionStatus.REJECTED
)
create_data = prisma.types.StoreListingSubmissionCreateInput(
StoreListingVersion={"connect": {"id": store_listing_version_id}},
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
createdAt=datetime.now(),
updatedAt=datetime.now(),
)
update_data = prisma.types.StoreListingSubmissionUpdateInput(
Status=status,
reviewComments=comments,
Reviewer={"connect": {"id": reviewer_id}},
StoreListing={"connect": {"id": store_listing_version.StoreListing.id}},
updatedAt=datetime.now(),
)
if is_approved:
await prisma.models.StoreListing.prisma().update(
where={"id": store_listing_version.StoreListing.id},
data={"isApproved": True},
)
submission = await prisma.models.StoreListingSubmission.prisma().upsert(
where={"storeListingVersionId": store_listing_version_id},
data=prisma.types.StoreListingSubmissionUpsertInput(
create=create_data,
update=update_data,
),
)
if not submission:
raise fastapi.HTTPException(
status_code=404,
detail=f"Store listing submission {store_listing_version_id} not found",
)
return submission
except Exception as e:
logger.error(f"Error reviewing store submission: {str(e)}")
raise backend.server.v2.store.exceptions.DatabaseError(
"Failed to review store submission"
) from e


@@ -115,6 +115,7 @@ class StoreSubmission(pydantic.BaseModel):
     status: prisma.enums.SubmissionStatus
     runs: int
     rating: float
+    store_listing_version_id: str | None = None
 
 
 class StoreSubmissionsResponse(pydantic.BaseModel):
@@ -151,3 +152,9 @@ class StoreReviewCreate(pydantic.BaseModel):
     store_listing_version_id: str
     score: int
     comments: str | None = None
class ReviewSubmissionRequest(pydantic.BaseModel):
store_listing_version_id: str
isApproved: bool
comments: str


@ -642,3 +642,33 @@ async def download_agent_file(
return fastapi.responses.FileResponse( return fastapi.responses.FileResponse(
tmp_file.name, filename=file_name, media_type="application/json" tmp_file.name, filename=file_name, media_type="application/json"
) )
@router.post(
"/submissions/review/{store_listing_version_id}",
tags=["store", "private"],
)
async def review_submission(
request: backend.server.v2.store.model.ReviewSubmissionRequest,
user: typing.Annotated[
autogpt_libs.auth.models.User,
fastapi.Depends(autogpt_libs.auth.depends.requires_admin_user),
],
):
# Proceed with the review submission logic
try:
submission = await backend.server.v2.store.db.review_store_submission(
store_listing_version_id=request.store_listing_version_id,
is_approved=request.isApproved,
comments=request.comments,
reviewer_id=user.user_id,
)
return submission
except Exception:
logger.exception("Exception occurred whilst reviewing store submission")
return fastapi.responses.JSONResponse(
status_code=500,
content={
"detail": "An error occurred while reviewing the store submission"
},
)


@@ -253,7 +253,7 @@ async def block_autogen_agent():
         test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
         input_data = {"input": "Write me a block that writes a string into a file."}
         response = await server.agent_server.test_execute_graph(
-            test_graph.id, input_data, test_user.id
+            test_graph.id, test_graph.version, input_data, test_user.id
         )
         print(response)
         result = await wait_execution(


@@ -157,7 +157,7 @@ async def reddit_marketing_agent():
         test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
         input_data = {"subreddit": "AutoGPT"}
         response = await server.agent_server.test_execute_graph(
-            test_graph.id, input_data, test_user.id
+            test_graph.id, test_graph.version, input_data, test_user.id
         )
         print(response)
         result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)


@@ -8,12 +8,19 @@ from backend.data.user import get_or_create_user
 from backend.util.test import SpinTestServer, wait_execution
 
 
-async def create_test_user() -> User:
-    test_user_data = {
-        "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
-        "email": "testuser#example.com",
-        "name": "Test User",
-    }
+async def create_test_user(alt_user: bool = False) -> User:
+    if alt_user:
+        test_user_data = {
+            "sub": "3e53486c-cf57-477e-ba2a-cb02dc828e1b",
+            "email": "testuser2#example.com",
+            "name": "Test User 2",
+        }
+    else:
+        test_user_data = {
+            "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
+            "email": "testuser#example.com",
+            "name": "Test User",
+        }
     user = await get_or_create_user(test_user_data)
     return user
@@ -79,7 +86,7 @@ async def sample_agent():
     test_graph = await create_graph(create_test_graph(), test_user.id)
     input_data = {"input_1": "Hello", "input_2": "World"}
     response = await server.agent_server.test_execute_graph(
-        test_graph.id, input_data, test_user.id
+        test_graph.id, test_graph.version, input_data, test_user.id
     )
     print(response)
     result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)


@@ -153,6 +153,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="The name of the Google Cloud Storage bucket for media files",
     )
 
+    reddit_user_agent: str = Field(
+        default="AutoGPT:1.0 (by /u/autogpt)",
+        description="The user agent for the Reddit API",
+    )
+
     scheduler_db_pool_size: int = Field(
         default=3,
         description="The pool size for the scheduler database connection pool",
@@ -276,8 +281,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
     reddit_client_id: str = Field(default="", description="Reddit client ID")
     reddit_client_secret: str = Field(default="", description="Reddit client secret")
-    reddit_username: str = Field(default="", description="Reddit username")
-    reddit_password: str = Field(default="", description="Reddit password")
 
     openweathermap_api_key: str = Field(
         default="", description="OpenWeatherMap API key"


@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "AgentPreset" ADD COLUMN "isDeleted" BOOLEAN NOT NULL DEFAULT false;


@@ -0,0 +1,46 @@
/*
Warnings:
- You are about to drop the `UserAgent` table. If the table is not empty, all the data it contains will be lost.
*/
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentId_agentVersion_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_agentPresetId_fkey";
-- DropForeignKey
ALTER TABLE "UserAgent" DROP CONSTRAINT "UserAgent_userId_fkey";
-- DropTable
DROP TABLE "UserAgent";
-- CreateTable
CREATE TABLE "LibraryAgent" (
"id" TEXT NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"userId" TEXT NOT NULL,
"agentId" TEXT NOT NULL,
"agentVersion" INTEGER NOT NULL,
"agentPresetId" TEXT,
"isFavorite" BOOLEAN NOT NULL DEFAULT false,
"isCreatedByUser" BOOLEAN NOT NULL DEFAULT false,
"isArchived" BOOLEAN NOT NULL DEFAULT false,
"isDeleted" BOOLEAN NOT NULL DEFAULT false,
CONSTRAINT "LibraryAgent_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LibraryAgent_userId_idx" ON "LibraryAgent"("userId");
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentId_agentVersion_fkey" FOREIGN KEY ("agentId", "agentVersion") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
-- AddForeignKey
ALTER TABLE "LibraryAgent" ADD CONSTRAINT "LibraryAgent_agentPresetId_fkey" FOREIGN KEY ("agentPresetId") REFERENCES "AgentPreset"("id") ON DELETE SET NULL ON UPDATE CASCADE;


@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "LibraryAgent" ADD COLUMN "useGraphIsActiveVersion" BOOLEAN NOT NULL DEFAULT false;


@@ -0,0 +1,29 @@
/*
Warnings:
- A unique constraint covering the columns `[agentId]` on the table `StoreListing` will be added. If there are existing duplicate values, this will fail.
*/
-- DropIndex
DROP INDEX "StoreListing_agentId_idx";
-- DropIndex
DROP INDEX "StoreListing_isApproved_idx";
-- DropIndex
DROP INDEX "StoreListingVersion_agentId_agentVersion_isApproved_idx";
-- CreateIndex
CREATE INDEX "StoreListing_agentId_owningUserId_idx" ON "StoreListing"("agentId", "owningUserId");
-- CreateIndex
CREATE INDEX "StoreListing_isDeleted_isApproved_idx" ON "StoreListing"("isDeleted", "isApproved");
-- CreateIndex
CREATE INDEX "StoreListing_isDeleted_idx" ON "StoreListing"("isDeleted");
-- CreateIndex
CREATE UNIQUE INDEX "StoreListing_agentId_key" ON "StoreListing"("agentId");
-- CreateIndex
CREATE INDEX "StoreListingVersion_agentId_agentVersion_isDeleted_idx" ON "StoreListingVersion"("agentId", "agentVersion", "isDeleted");


@@ -30,7 +30,7 @@ model User {
   CreditTransaction CreditTransaction[]
   AgentPreset       AgentPreset[]
-  UserAgent         UserAgent[]
+  LibraryAgent      LibraryAgent[]
   Profile           Profile[]
   StoreListing      StoreListing[]
@@ -65,7 +65,7 @@ model AgentGraph {
   AgentGraphExecution AgentGraphExecution[]
   AgentPreset         AgentPreset[]
-  UserAgent           UserAgent[]
+  LibraryAgent        LibraryAgent[]
   StoreListing        StoreListing[]
   StoreListingVersion StoreListingVersion?
@@ -103,15 +103,17 @@ model AgentPreset {
   Agent          AgentGraph @relation(fields: [agentId, agentVersion], references: [id, version], onDelete: Cascade)
   InputPresets   AgentNodeExecutionInputOutput[] @relation("AgentPresetsInputData")
-  UserAgents     UserAgent[]
+  LibraryAgents  LibraryAgent[]
   AgentExecution AgentGraphExecution[]

+  isDeleted Boolean @default(false)
+
   @@index([userId])
 }
 // For the library page
 // It is a user controlled list of agents, that they will see in there library
-model UserAgent {
+model LibraryAgent {
   id        String   @id @default(uuid())
   createdAt DateTime @default(now())
   updatedAt DateTime @default(now()) @updatedAt
@@ -126,6 +128,8 @@ model UserAgent {
   agentPresetId String?
   AgentPreset   AgentPreset? @relation(fields: [agentPresetId], references: [id])

+  useGraphIsActiveVersion Boolean @default(false)
+
   isFavorite      Boolean @default(false)
   isCreatedByUser Boolean @default(false)
   isArchived      Boolean @default(false)
@@ -235,7 +239,7 @@ model AgentGraphExecution {
   AgentNodeExecutions AgentNodeExecution[]

-  // Link to User model
+  // Link to User model -- Executed by this user
   userId String
   user   User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@ -443,6 +447,8 @@ view Creator {
   agent_rating Float
   agent_runs   Int
   is_featured  Boolean
+
+  // Index or unique are not applied to views
 }
 view StoreAgent {
@@ -465,11 +471,7 @@ view StoreAgent {
   rating   Float
   versions String[]

-  @@unique([creator_username, slug])
-  @@index([creator_username])
-  @@index([featured])
-  @@index([categories])
-  @@index([storeListingVersionId])
+  // Index or unique are not applied to views
 }
 view StoreSubmission {
@@ -487,7 +489,7 @@ view StoreSubmission {
   agent_id      String
   agent_version Int

-  @@index([user_id])
+  // Index or unique are not applied to views
 }
 model StoreListing {
@@ -510,9 +512,13 @@ model StoreListing {
   StoreListingVersions   StoreListingVersion[]
   StoreListingSubmission StoreListingSubmission[]

-  @@index([isApproved])
-  @@index([agentId])
+  // Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
+  @@unique([agentId])
+  @@index([agentId, owningUserId])
   @@index([owningUserId])
+  // Used in the view query
+  @@index([isDeleted, isApproved])
+  @@index([isDeleted])
 }
 model StoreListingVersion {
@@ -553,7 +559,7 @@ model StoreListingVersion {
   StoreListingReview StoreListingReview[]

   @@unique([agentId, agentVersion])
-  @@index([agentId, agentVersion, isApproved])
+  @@index([agentId, agentVersion, isDeleted])
 }

 model StoreListingReview {

@@ -1,8 +1,11 @@
 from typing import Any
 from uuid import UUID

+import autogpt_libs.auth.models
+import fastapi.exceptions
 import pytest

+import backend.server.v2.store.model
 from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
 from backend.data.block import BlockSchema
 from backend.data.graph import Graph, Link, Node
@@ -202,3 +205,92 @@ async def test_clean_graph(server: SpinTestServer):
         n for n in created_graph.nodes if n.block_id == AgentInputBlock().id
     )
     assert input_node.input_default["value"] == ""
@pytest.mark.asyncio(scope="session")
async def test_access_store_listing_graph(server: SpinTestServer):
    """
    Test the access of a store listing graph.
    """
    graph = Graph(
        id="test_clean_graph",
        name="Test Clean Graph",
        description="Test graph cleaning",
        nodes=[
            Node(
                id="input_node",
                block_id=AgentInputBlock().id,
                input_default={
                    "name": "test_input",
                    "value": "test value",
                    "description": "Test input description",
                },
            ),
        ],
        links=[],
    )

    # Create graph and get model
    create_graph = CreateGraph(graph=graph)
    created_graph = await server.agent_server.test_create_graph(
        create_graph, DEFAULT_USER_ID
    )

    store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
        agent_id=created_graph.id,
        agent_version=created_graph.version,
        slug="test-slug",
        name="Test name",
        sub_heading="Test sub heading",
        video_url=None,
        image_urls=[],
        description="Test description",
        categories=[],
    )

    # First we check that the graph cannot be accessed by a different user
    with pytest.raises(fastapi.exceptions.HTTPException) as exc_info:
        await server.agent_server.test_get_graph(
            created_graph.id,
            created_graph.version,
            "3e53486c-cf57-477e-ba2a-cb02dc828e1b",
        )

    assert exc_info.value.status_code == 404
    assert "Graph" in str(exc_info.value.detail)

    # Now we create a store listing
    store_listing = await server.agent_server.test_create_store_listing(
        store_submission_request, DEFAULT_USER_ID
    )

    if isinstance(store_listing, fastapi.responses.JSONResponse):
        assert False, "Failed to create store listing"

    slv_id = (
        store_listing.store_listing_version_id
        if store_listing.store_listing_version_id is not None
        else None
    )

    assert slv_id is not None

    admin = autogpt_libs.auth.models.User(
        user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
        role="admin",
        email="admin@example.com",
        phone_number="1234567890",
    )
    await server.agent_server.test_review_store_listing(
        backend.server.v2.store.model.ReviewSubmissionRequest(
            store_listing_version_id=slv_id,
            isApproved=True,
            comments="Test comments",
        ),
        admin,
    )

    # Now we check the graph can be accessed by a user that does not own the graph
    got_graph = await server.agent_server.test_get_graph(
        created_graph.id, created_graph.version, "3e53486c-cf57-477e-ba2a-cb02dc828e1b"
    )
    assert got_graph is not None


@@ -1,9 +1,13 @@
 import logging

+import autogpt_libs.auth.models
+import fastapi.responses
 import pytest
 from prisma.models import User

-from backend.blocks.basic import FindInDictionaryBlock, StoreValueBlock
+import backend.server.v2.library.model
+import backend.server.v2.store.model
+from backend.blocks.basic import AgentInputBlock, FindInDictionaryBlock, StoreValueBlock
 from backend.blocks.maths import CalculatorBlock, Operation
 from backend.data import execution, graph
 from backend.server.model import CreateGraph
@@ -31,7 +35,7 @@ async def execute_graph(
     # --- Test adding new executions --- #
     response = await agent_server.test_execute_graph(
-        test_graph.id, input_data, test_user.id
+        test_graph.id, test_graph.version, input_data, test_user.id
     )
     graph_exec_id = response["id"]
     logger.info(f"Created execution with ID: {graph_exec_id}")
@@ -287,3 +291,255 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
     assert exec_data.status == execution.ExecutionStatus.COMPLETED
     assert exec_data.output_data == {"result": [9]}
     logger.info("Completed test_static_input_link_on_graph")
@pytest.mark.asyncio(scope="session")
async def test_execute_preset(server: SpinTestServer):
    """
    Test executing a preset.

    This test ensures that:
    1. A preset can be successfully executed
    2. The execution results are correct

    Args:
        server (SpinTestServer): The test server instance.
    """
    # Create test graph and user
    nodes = [
        graph.Node(  # 0
            block_id=AgentInputBlock().id,
            input_default={"name": "dictionary"},
        ),
        graph.Node(  # 1
            block_id=AgentInputBlock().id,
            input_default={"name": "selected_value"},
        ),
        graph.Node(  # 2
            block_id=StoreValueBlock().id,
            input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
        ),
        graph.Node(  # 3
            block_id=FindInDictionaryBlock().id,
            input_default={"key": "", "input": {}},
        ),
    ]
    links = [
        graph.Link(
            source_id=nodes[0].id,
            sink_id=nodes[2].id,
            source_name="result",
            sink_name="input",
        ),
        graph.Link(
            source_id=nodes[1].id,
            sink_id=nodes[3].id,
            source_name="result",
            sink_name="key",
        ),
        graph.Link(
            source_id=nodes[2].id,
            sink_id=nodes[3].id,
            source_name="output",
            sink_name="input",
        ),
    ]
    test_graph = graph.Graph(
        name="TestGraph",
        description="Test graph",
        nodes=nodes,
        links=links,
    )
    test_user = await create_test_user()
    test_graph = await create_graph(server, test_graph, test_user)

    # Create preset with initial values
    preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
        name="Test Preset With Clash",
        description="Test preset with clashing input values",
        agent_id=test_graph.id,
        agent_version=test_graph.version,
        inputs={
            "dictionary": {"key1": "Hello", "key2": "World"},
            "selected_value": "key2",
        },
        is_active=True,
    )
    created_preset = await server.agent_server.test_create_preset(preset, test_user.id)

    # Execute preset
    result = await server.agent_server.test_execute_preset(
        graph_id=test_graph.id,
        graph_version=test_graph.version,
        preset_id=created_preset.id,
        node_input={},
        user_id=test_user.id,
    )

    # Verify execution
    assert result is not None
    graph_exec_id = result["id"]

    # Wait for execution to complete
    executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
    assert len(executions) == 4

    # FindInDictionaryBlock should wait for the input pin to be provided,
    # hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
    assert executions[3].status == execution.ExecutionStatus.COMPLETED
    assert executions[3].output_data == {"output": ["World"]}
@pytest.mark.asyncio(scope="session")
async def test_execute_preset_with_clash(server: SpinTestServer):
    """
    Test executing a preset with clashing input data.
    """
    # Create test graph and user
    nodes = [
        graph.Node(  # 0
            block_id=AgentInputBlock().id,
            input_default={"name": "dictionary"},
        ),
        graph.Node(  # 1
            block_id=AgentInputBlock().id,
            input_default={"name": "selected_value"},
        ),
        graph.Node(  # 2
            block_id=StoreValueBlock().id,
            input_default={"input": {"key1": "Hi", "key2": "Everyone"}},
        ),
        graph.Node(  # 3
            block_id=FindInDictionaryBlock().id,
            input_default={"key": "", "input": {}},
        ),
    ]
    links = [
        graph.Link(
            source_id=nodes[0].id,
            sink_id=nodes[2].id,
            source_name="result",
            sink_name="input",
        ),
        graph.Link(
            source_id=nodes[1].id,
            sink_id=nodes[3].id,
            source_name="result",
            sink_name="key",
        ),
        graph.Link(
            source_id=nodes[2].id,
            sink_id=nodes[3].id,
            source_name="output",
            sink_name="input",
        ),
    ]
    test_graph = graph.Graph(
        name="TestGraph",
        description="Test graph",
        nodes=nodes,
        links=links,
    )
    test_user = await create_test_user()
    test_graph = await create_graph(server, test_graph, test_user)

    # Create preset with initial values
    preset = backend.server.v2.library.model.CreateLibraryAgentPresetRequest(
        name="Test Preset With Clash",
        description="Test preset with clashing input values",
        agent_id=test_graph.id,
        agent_version=test_graph.version,
        inputs={
            "dictionary": {"key1": "Hello", "key2": "World"},
            "selected_value": "key2",
        },
        is_active=True,
    )
    created_preset = await server.agent_server.test_create_preset(preset, test_user.id)

    # Execute preset with overriding values
    result = await server.agent_server.test_execute_preset(
        graph_id=test_graph.id,
        graph_version=test_graph.version,
        preset_id=created_preset.id,
        node_input={"selected_value": "key1"},
        user_id=test_user.id,
    )

    # Verify execution
    assert result is not None
    graph_exec_id = result["id"]

    # Wait for execution to complete
    executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
    assert len(executions) == 4

    # FindInDictionaryBlock should wait for the input pin to be provided,
    # hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
    assert executions[3].status == execution.ExecutionStatus.COMPLETED
    assert executions[3].output_data == {"output": ["Hello"]}
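The clash test above pins down the precedence rule: inputs passed at execution time override the preset's stored values for the same key. A minimal sketch of that behavior as a shallow dict merge — the actual server-side merge in the preset execution route may differ:

```python
def merge_preset_inputs(preset_inputs: dict, node_input: dict) -> dict:
    # Execution-time node_input wins over stored preset values on key clashes;
    # preset keys that are not overridden pass through unchanged.
    return {**preset_inputs, **node_input}
```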
@pytest.mark.asyncio(scope="session")
async def test_store_listing_graph(server: SpinTestServer):
    logger.info("Starting test_agent_execution")
    test_user = await create_test_user()
    test_graph = await create_graph(server, create_test_graph(), test_user)

    store_submission_request = backend.server.v2.store.model.StoreSubmissionRequest(
        agent_id=test_graph.id,
        agent_version=test_graph.version,
        slug="test-slug",
        name="Test name",
        sub_heading="Test sub heading",
        video_url=None,
        image_urls=[],
        description="Test description",
        categories=[],
    )

    store_listing = await server.agent_server.test_create_store_listing(
        store_submission_request, test_user.id
    )

    if isinstance(store_listing, fastapi.responses.JSONResponse):
        assert False, "Failed to create store listing"

    slv_id = (
        store_listing.store_listing_version_id
        if store_listing.store_listing_version_id is not None
        else None
    )

    assert slv_id is not None

    admin = autogpt_libs.auth.models.User(
        user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1b",
        role="admin",
        email="admin@example.com",
        phone_number="1234567890",
    )
    await server.agent_server.test_review_store_listing(
        backend.server.v2.store.model.ReviewSubmissionRequest(
            store_listing_version_id=slv_id,
            isApproved=True,
            comments="Test comments",
        ),
        admin,
    )

    alt_test_user = await create_test_user(alt_user=True)

    data = {"input_1": "Hello", "input_2": "World"}
    graph_exec_id = await execute_graph(
        server.agent_server,
        test_graph,
        alt_test_user,
        data,
        4,
    )

    await assert_sample_graph_executions(
        server.agent_server, test_graph, alt_test_user, graph_exec_id
    )
    logger.info("Completed test_agent_execution")


@@ -140,10 +140,10 @@ async def main():
     print(f"Inserting {NUM_USERS * MAX_AGENTS_PER_USER} user agents")
     for user in users:
         num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
-        for _ in range(num_agents):  # Create 1 UserAgent per user
+        for _ in range(num_agents):  # Create 1 LibraryAgent per user
             graph = random.choice(agent_graphs)
             preset = random.choice(agent_presets)
-            user_agent = await db.useragent.create(
+            user_agent = await db.libraryagent.create(
                 data={
                     "userId": user.id,
                     "agentId": graph.id,


@@ -123,14 +123,22 @@ export default function PrivatePage() {
   const allCredentials = providers
     ? Object.values(providers).flatMap((provider) =>
-        [...provider.savedOAuthCredentials, ...provider.savedApiKeys]
+        [
+          ...provider.savedOAuthCredentials,
+          ...provider.savedApiKeys,
+          ...provider.savedUserPasswordCredentials,
+        ]
           .filter((cred) => !hiddenCredentials.includes(cred.id))
           .map((credentials) => ({
             ...credentials,
             provider: provider.provider,
             providerName: provider.providerName,
             ProviderIcon: providerIcons[provider.provider],
-            TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
+            TypeIcon: {
+              oauth2: IconUser,
+              api_key: IconKey,
+              user_password: IconKey,
+            }[credentials.type],
           })),
       )
     : [];
@@ -175,6 +183,7 @@ export default function PrivatePage() {
       {
         oauth2: "OAuth2 credentials",
         api_key: "API key",
+        user_password: "User password",
       }[cred.type]
     }{" "}
     - <code>{cred.id}</code>



@@ -73,7 +73,9 @@ export const providerIcons: Record<
   open_router: fallbackIcon,
   pinecone: fallbackIcon,
   slant3d: fallbackIcon,
+  smtp: fallbackIcon,
   replicate: fallbackIcon,
+  reddit: fallbackIcon,
   fal: fallbackIcon,
   revid: fallbackIcon,
   twitter: FaTwitter,
@@ -105,6 +107,10 @@ export const CredentialsInput: FC<{
   const credentials = useCredentials(selfKey);
   const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
     useState(false);
+  const [
+    isUserPasswordCredentialsModalOpen,
+    setUserPasswordCredentialsModalOpen,
+  ] = useState(false);
   const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
   const [oAuthPopupController, setOAuthPopupController] =
     useState<AbortController | null>(null);
@@ -120,8 +126,10 @@ export const CredentialsInput: FC<{
     providerName,
     supportsApiKey,
     supportsOAuth2,
+    supportsUserPassword,
     savedApiKeys,
     savedOAuthCredentials,
+    savedUserPasswordCredentials,
     oAuthCallback,
   } = credentials;
@@ -235,6 +243,17 @@ export const CredentialsInput: FC<{
         providerName={providerName}
       />
     )}
+    {supportsUserPassword && (
+      <UserPasswordCredentialsModal
+        credentialsFieldName={selfKey}
+        open={isUserPasswordCredentialsModalOpen}
+        onClose={() => setUserPasswordCredentialsModalOpen(false)}
+        onCredentialsCreate={(creds) => {
+          onSelectCredentials(creds);
+          setUserPasswordCredentialsModalOpen(false);
+        }}
+      />
+    )}
   </>
 );
@@ -243,13 +262,18 @@ export const CredentialsInput: FC<{
     selectedCredentials &&
     !savedApiKeys
       .concat(savedOAuthCredentials)
+      .concat(savedUserPasswordCredentials)
       .some((c) => c.id === selectedCredentials.id)
   ) {
     onSelectCredentials(undefined);
   }

   // No saved credentials yet
-  if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
+  if (
+    savedApiKeys.length === 0 &&
+    savedOAuthCredentials.length === 0 &&
+    savedUserPasswordCredentials.length === 0
+  ) {
     return (
       <>
         <div className="mb-2 flex gap-1">
@@ -271,6 +295,12 @@ export const CredentialsInput: FC<{
             Enter API key
           </Button>
         )}
+        {supportsUserPassword && (
+          <Button onClick={() => setUserPasswordCredentialsModalOpen(true)}>
+            <ProviderIcon className="mr-2 h-4 w-4" />
+            Enter user password
+          </Button>
+        )}
       </div>
       {modals}
       {oAuthError && (
@@ -280,12 +310,29 @@ export const CredentialsInput: FC<{
     );
   }

-  const singleCredential =
-    savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
-      ? savedApiKeys[0]
-      : savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
-        ? savedOAuthCredentials[0]
-        : null;
+  const getCredentialCounts = () => ({
+    apiKeys: savedApiKeys.length,
+    oauth: savedOAuthCredentials.length,
+    userPass: savedUserPasswordCredentials.length,
+  });
+
+  const getSingleCredential = () => {
+    const counts = getCredentialCounts();
+    const totalCredentials = Object.values(counts).reduce(
+      (sum, count) => sum + count,
+      0,
+    );
+    if (totalCredentials !== 1) return null;
+    if (counts.apiKeys === 1) return savedApiKeys[0];
+    if (counts.oauth === 1) return savedOAuthCredentials[0];
+    if (counts.userPass === 1) return savedUserPasswordCredentials[0];
+    return null;
+  };
+
+  const singleCredential = getSingleCredential();
 if (singleCredential) {
   if (!selectedCredentials) {
@@ -309,6 +356,7 @@ export const CredentialsInput: FC<{
   } else {
     const selectedCreds = savedApiKeys
       .concat(savedOAuthCredentials)
+      .concat(savedUserPasswordCredentials)
       .find((c) => c.id == newValue)!;

     onSelectCredentials({
@@ -347,6 +395,13 @@ export const CredentialsInput: FC<{
           {credentials.title}
         </SelectItem>
       ))}
+      {savedUserPasswordCredentials.map((credentials, index) => (
+        <SelectItem key={index} value={credentials.id}>
+          <ProviderIcon className="mr-2 inline h-4 w-4" />
+          <IconUserPlus className="mr-1.5 inline" />
+          {credentials.title}
+        </SelectItem>
+      ))}
       <SelectSeparator />
       {supportsOAuth2 && (
         <SelectItem value="sign-in">
@@ -360,6 +415,12 @@ export const CredentialsInput: FC<{
           Add new API key
         </SelectItem>
       )}
+      {supportsUserPassword && (
+        <SelectItem value="add-user-password">
+          <IconUserPlus className="mr-1.5 inline" />
+          Add new user password
+        </SelectItem>
+      )}
     </SelectContent>
   </Select>
   {modals}
@@ -506,6 +567,128 @@ export const APIKeyCredentialsModal: FC<{
   );
 };

export const UserPasswordCredentialsModal: FC<{
  credentialsFieldName: string;
  open: boolean;
  onClose: () => void;
  onCredentialsCreate: (creds: CredentialsMetaInput) => void;
}> = ({ credentialsFieldName, open, onClose, onCredentialsCreate }) => {
  const credentials = useCredentials(credentialsFieldName);

  const formSchema = z.object({
    username: z.string().min(1, "Username is required"),
    password: z.string().min(1, "Password is required"),
    title: z.string().min(1, "Name is required"),
  });

  const form = useForm<z.infer<typeof formSchema>>({
    resolver: zodResolver(formSchema),
    defaultValues: {
      username: "",
      password: "",
      title: "",
    },
  });

  if (
    !credentials ||
    credentials.isLoading ||
    !credentials.supportsUserPassword
  ) {
    return null;
  }

  const { schema, provider, providerName, createUserPasswordCredentials } =
    credentials;

  async function onSubmit(values: z.infer<typeof formSchema>) {
    const newCredentials = await createUserPasswordCredentials({
      username: values.username,
      password: values.password,
      title: values.title,
    });
    onCredentialsCreate({
      provider,
      id: newCredentials.id,
      type: "user_password",
      title: newCredentials.title,
    });
  }

  return (
    <Dialog
      open={open}
      onOpenChange={(open) => {
        if (!open) onClose();
      }}
    >
      <DialogContent>
        <DialogHeader>
          <DialogTitle>Add new user password for {providerName}</DialogTitle>
        </DialogHeader>
        <Form {...form}>
          <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
            <FormField
              control={form.control}
              name="username"
              render={({ field }) => (
                <FormItem>
                  <FormLabel>Username</FormLabel>
                  <FormControl>
                    <Input
                      type="text"
                      placeholder="Enter username..."
                      {...field}
                    />
                  </FormControl>
                  <FormMessage />
                </FormItem>
              )}
            />
            <FormField
              control={form.control}
              name="password"
              render={({ field }) => (
                <FormItem>
                  <FormLabel>Password</FormLabel>
                  <FormControl>
                    <Input
                      type="password"
                      placeholder="Enter password..."
                      {...field}
                    />
                  </FormControl>
                  <FormMessage />
                </FormItem>
              )}
            />
            <FormField
              control={form.control}
              name="title"
              render={({ field }) => (
                <FormItem>
                  <FormLabel>Name</FormLabel>
                  <FormControl>
                    <Input
                      type="text"
                      placeholder="Enter a name for this user password..."
                      {...field}
                    />
                  </FormControl>
                  <FormMessage />
                </FormItem>
              )}
            />
            <Button type="submit" className="w-full">
              Save & use this user password
            </Button>
          </form>
        </Form>
      </DialogContent>
    </Dialog>
  );
};

export const OAuth2FlowWaitingModal: FC<{
  open: boolean;
  onClose: () => void;


@@ -5,6 +5,7 @@ import {
   CredentialsMetaResponse,
   CredentialsProviderName,
   PROVIDER_NAMES,
+  UserPasswordCredentials,
 } from "@/lib/autogpt-server-api";
 import { useBackendAPI } from "@/lib/autogpt-server-api/context";
 import { createContext, useCallback, useEffect, useState } from "react";
@@ -20,10 +21,13 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
   discord: "Discord",
   d_id: "D-ID",
   e2b: "E2B",
+  exa: "Exa",
+  fal: "FAL",
   github: "GitHub",
   google: "Google",
   google_maps: "Google Maps",
   groq: "Groq",
+  hubspot: "Hubspot",
   ideogram: "Ideogram",
   jina: "Jina",
   medium: "Medium",
@@ -35,13 +39,12 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
   open_router: "Open Router",
   pinecone: "Pinecone",
   slant3d: "Slant3D",
+  smtp: "SMTP",
+  reddit: "Reddit",
   replicate: "Replicate",
-  fal: "FAL",
   revid: "Rev.ID",
   twitter: "Twitter",
   unreal_speech: "Unreal Speech",
-  exa: "Exa",
-  hubspot: "Hubspot",
 } as const;
 // --8<-- [end:CredentialsProviderNames]
@@ -50,11 +53,17 @@ type APIKeyCredentialsCreatable = Omit<
   "id" | "provider" | "type"
 >;

+type UserPasswordCredentialsCreatable = Omit<
+  UserPasswordCredentials,
+  "id" | "provider" | "type"
+>;
+
 export type CredentialsProviderData = {
   provider: CredentialsProviderName;
   providerName: string;
   savedApiKeys: CredentialsMetaResponse[];
   savedOAuthCredentials: CredentialsMetaResponse[];
+  savedUserPasswordCredentials: CredentialsMetaResponse[];
   oAuthCallback: (
     code: string,
     state_token: string,
@@ -62,6 +71,9 @@ export type CredentialsProviderData = {
   createAPIKeyCredentials: (
     credentials: APIKeyCredentialsCreatable,
   ) => Promise<CredentialsMetaResponse>;
+  createUserPasswordCredentials: (
+    credentials: UserPasswordCredentialsCreatable,
+  ) => Promise<CredentialsMetaResponse>;
   deleteCredentials: (
     id: string,
     force?: boolean,
@@ -106,6 +118,11 @@ export default function CredentialsProvider({
         ...updatedProvider.savedOAuthCredentials,
         credentials,
       ];
+    } else if (credentials.type === "user_password") {
+      updatedProvider.savedUserPasswordCredentials = [
+        ...updatedProvider.savedUserPasswordCredentials,
+        credentials,
+      ];
     }

     return {
@@ -147,6 +164,22 @@ export default function CredentialsProvider({
     [api, addCredentials],
   );

+  /** Wraps `BackendAPI.createUserPasswordCredentials`, and adds the result to the internal credentials store. */
+  const createUserPasswordCredentials = useCallback(
+    async (
+      provider: CredentialsProviderName,
+      credentials: UserPasswordCredentialsCreatable,
+    ): Promise<CredentialsMetaResponse> => {
+      const credsMeta = await api.createUserPasswordCredentials({
+        provider,
+        ...credentials,
+      });
+      addCredentials(provider, credsMeta);
+      return credsMeta;
+    },
+    [api, addCredentials],
+  );
+
   /** Wraps `BackendAPI.deleteCredentials`, and removes the credentials from the internal store. */
   const deleteCredentials = useCallback(
     async (
@@ -171,7 +204,10 @@ export default function CredentialsProvider({
         updatedProvider.savedOAuthCredentials.filter(
           (cred) => cred.id !== id,
         );
+      updatedProvider.savedUserPasswordCredentials =
+        updatedProvider.savedUserPasswordCredentials.filter(
+          (cred) => cred.id !== id,
+        );
       return {
         ...prev,
         [provider]: updatedProvider,
@ -190,12 +226,18 @@ export default function CredentialsProvider({
const credentialsByProvider = response.reduce( const credentialsByProvider = response.reduce(
(acc, cred) => { (acc, cred) => {
if (!acc[cred.provider]) { if (!acc[cred.provider]) {
acc[cred.provider] = { oauthCreds: [], apiKeys: [] }; acc[cred.provider] = {
oauthCreds: [],
apiKeys: [],
userPasswordCreds: [],
};
} }
if (cred.type === "oauth2") { if (cred.type === "oauth2") {
acc[cred.provider].oauthCreds.push(cred); acc[cred.provider].oauthCreds.push(cred);
} else if (cred.type === "api_key") { } else if (cred.type === "api_key") {
acc[cred.provider].apiKeys.push(cred); acc[cred.provider].apiKeys.push(cred);
} else if (cred.type === "user_password") {
acc[cred.provider].userPasswordCreds.push(cred);
}
return acc;
},
@@ -204,6 +246,7 @@ export default function CredentialsProvider({
{
oauthCreds: CredentialsMetaResponse[];
apiKeys: CredentialsMetaResponse[];
userPasswordCreds: CredentialsMetaResponse[];
}
>,
);
@@ -220,6 +263,8 @@ export default function CredentialsProvider({
savedApiKeys: credentialsByProvider[provider]?.apiKeys ?? [],
savedOAuthCredentials:
credentialsByProvider[provider]?.oauthCreds ?? [],
savedUserPasswordCredentials:
credentialsByProvider[provider]?.userPasswordCreds ?? [],
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(
provider as CredentialsProviderName,
@@ -233,6 +278,13 @@ export default function CredentialsProvider({
provider as CredentialsProviderName,
credentials,
),
createUserPasswordCredentials: (
credentials: UserPasswordCredentialsCreatable,
) =>
createUserPasswordCredentials(
provider as CredentialsProviderName,
credentials,
),
deleteCredentials: (id: string, force: boolean = false) =>
deleteCredentials(
provider as CredentialsProviderName,
@@ -245,7 +297,13 @@ export default function CredentialsProvider({
}));
});
});
}, [
api,
createAPIKeyCredentials,
createUserPasswordCredentials,
deleteCredentials,
oAuthCallback,
]);
return (
<CredentialsProvidersContext.Provider value={providers}>
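The grouping reduce in this provider can be exercised in isolation. Below is a minimal, non-React sketch of the same bucketing logic, using a pared-down stand-in for `CredentialsMetaResponse` (the trimmed shape is an assumption for illustration):

```typescript
// Trimmed stand-in for CredentialsMetaResponse; only the fields the
// grouping logic touches (assumed shape, for illustration).
type CredMeta = {
  id: string;
  provider: string;
  type: "oauth2" | "api_key" | "user_password";
};

type Grouped = Record<
  string,
  {
    oauthCreds: CredMeta[];
    apiKeys: CredMeta[];
    userPasswordCreds: CredMeta[];
  }
>;

// Same shape as the reduce above: bucket each credential under its
// provider, split into one list per credential type.
function groupByProvider(creds: CredMeta[]): Grouped {
  return creds.reduce((acc, cred) => {
    if (!acc[cred.provider]) {
      acc[cred.provider] = {
        oauthCreds: [],
        apiKeys: [],
        userPasswordCreds: [],
      };
    }
    if (cred.type === "oauth2") {
      acc[cred.provider].oauthCreds.push(cred);
    } else if (cred.type === "api_key") {
      acc[cred.provider].apiKeys.push(cred);
    } else if (cred.type === "user_password") {
      acc[cred.provider].userPasswordCreds.push(cred);
    }
    return acc;
  }, {} as Grouped);
}

const grouped = groupByProvider([
  { id: "1", provider: "github", type: "api_key" },
  { id: "2", provider: "smtp", type: "user_password" },
]);
```

Every provider entry always carries all three buckets, so downstream consumers such as `savedUserPasswordCredentials` can rely on `?? []`-style fallbacks without per-type existence checks.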

View File

@@ -17,12 +17,14 @@ export type CredentialsData =
schema: BlockIOCredentialsSubSchema;
supportsApiKey: boolean;
supportsOAuth2: boolean;
supportsUserPassword: boolean;
isLoading: true;
}
| (CredentialsProviderData & {
schema: BlockIOCredentialsSubSchema;
supportsApiKey: boolean;
supportsOAuth2: boolean;
supportsUserPassword: boolean;
isLoading: false;
});
@@ -72,6 +74,8 @@ export default function useCredentials(
const supportsApiKey =
credentialsSchema.credentials_types.includes("api_key");
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
const supportsUserPassword =
credentialsSchema.credentials_types.includes("user_password");
// No provider means maybe it's still loading
if (!provider) {
@@ -93,13 +97,17 @@
)
: provider.savedOAuthCredentials;
const savedUserPasswordCredentials = provider.savedUserPasswordCredentials;
return {
...provider,
provider: providerName,
schema: credentialsSchema,
supportsApiKey,
supportsOAuth2,
supportsUserPassword,
savedOAuthCredentials,
savedUserPasswordCredentials,
isLoading: false,
};
}

View File

@@ -15,7 +15,6 @@ import {
GraphUpdateable,
NodeExecutionResult,
MyAgentsResponse,
ProfileDetails,
User,
StoreAgentsResponse,
@@ -29,6 +28,8 @@ import {
StoreReview,
ScheduleCreatable,
Schedule,
UserPasswordCredentials,
Credentials,
APIKeyPermission,
CreateAPIKeyResponse,
APIKey,
@@ -191,7 +192,17 @@ export default class BackendAPI {
return this._request(
"POST",
`/integrations/${credentials.provider}/credentials`,
{ ...credentials, type: "api_key" },
);
}
createUserPasswordCredentials(
credentials: Omit<UserPasswordCredentials, "id" | "type">,
): Promise<UserPasswordCredentials> {
return this._request(
"POST",
`/integrations/${credentials.provider}/credentials`,
{ ...credentials, type: "user_password" },
);
}
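Both create methods delegate to the private `_request` and stamp the discriminant `type` onto the payload via an object spread. A rough standalone sketch of that contract, with a hypothetical `fakeRequest` stub standing in for `_request` (the stub and its recording array are assumptions, not part of the real client):

```typescript
type UserPasswordPayload = {
  provider: string;
  title: string;
  username: string;
  password: string;
};

// Hypothetical transport stub: records what would be POSTed and echoes
// the body back, so the stamped payload can be inspected.
const sent: { method: string; path: string; body: unknown }[] = [];
function fakeRequest(method: string, path: string, body: unknown): unknown {
  sent.push({ method, path, body });
  return body;
}

// Mirrors createUserPasswordCredentials: spread the caller's input and
// stamp type: "user_password" before sending.
function createUserPasswordCredentials(credentials: UserPasswordPayload) {
  return fakeRequest(
    "POST",
    `/integrations/${credentials.provider}/credentials`,
    { ...credentials, type: "user_password" },
  );
}

const result = createUserPasswordCredentials({
  provider: "smtp",
  title: "Mail login",
  username: "alice",
  password: "hunter2",
}) as UserPasswordPayload & { type: string };
```

Because the `type` field is added server-bound rather than supplied by callers, the public signature can stay `Omit<UserPasswordCredentials, "id" | "type">` and callers cannot mislabel a credential.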
@@ -203,10 +214,7 @@
);
}
getCredentials(provider: string, id: string): Promise<Credentials> {
return this._get(`/integrations/${provider}/credentials/${id}`);
}

View File

@@ -97,7 +97,12 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
default?: boolean;
};
export type CredentialsType = "api_key" | "oauth2" | "user_password";
export type Credentials =
| APIKeyCredentials
| OAuth2Credentials
| UserPasswordCredentials;
// --8<-- [start:BlockIOCredentialsSubSchema]
export const PROVIDER_NAMES = {
@@ -105,10 +110,13 @@ export const PROVIDER_NAMES = {
D_ID: "d_id",
DISCORD: "discord",
E2B: "e2b",
EXA: "exa",
FAL: "fal",
GITHUB: "github",
GOOGLE: "google",
GOOGLE_MAPS: "google_maps",
GROQ: "groq",
HUBSPOT: "hubspot",
IDEOGRAM: "ideogram",
JINA: "jina",
MEDIUM: "medium",
@@ -120,13 +128,12 @@ export const PROVIDER_NAMES = {
OPEN_ROUTER: "open_router",
PINECONE: "pinecone",
SLANT3D: "slant3d",
SMTP: "smtp",
TWITTER: "twitter",
REPLICATE: "replicate",
REDDIT: "reddit",
REVID: "revid",
UNREAL_SPEECH: "unreal_speech",
} as const;
// --8<-- [end:BlockIOCredentialsSubSchema]
@@ -322,8 +329,15 @@ export type APIKeyCredentials = BaseCredentials & {
expires_at?: number;
};
export type UserPasswordCredentials = BaseCredentials & {
type: "user_password";
title: string;
username: string;
password: string;
};
/* Mirror of backend/data/integrations.py:Webhook */
export type Webhook = {
id: string;
url: string;
provider: CredentialsProviderName;
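The new `Credentials` union introduced above is discriminated on `type`, so consumers can narrow to `UserPasswordCredentials` without casts. A small sketch with trimmed stand-ins for the three members (shapes abbreviated for illustration; the real types carry more fields):

```typescript
// Abbreviated stand-ins for the union members: only the discriminant
// plus one payload field each (assumed shapes, for illustration).
type APIKeyCredentials = { type: "api_key"; api_key: string };
type OAuth2Credentials = { type: "oauth2"; access_token: string };
type UserPasswordCredentials = {
  type: "user_password";
  username: string;
  password: string;
};
type Credentials =
  | APIKeyCredentials
  | OAuth2Credentials
  | UserPasswordCredentials;

// Exhaustive switch on the discriminant; the `never` assignment in the
// default branch makes the compiler flag any credential type that is
// added to the union later but not handled here.
function describeCredentials(cred: Credentials): string {
  switch (cred.type) {
    case "api_key":
      return "API key";
    case "oauth2":
      return "OAuth2 token";
    case "user_password":
      return `login for ${cred.username}`;
    default: {
      const exhaustive: never = cred;
      return exhaustive;
    }
  }
}
```

This is why `getCredentials` can widen its return type from `APIKeyCredentials | OAuth2Credentials` to `Credentials`: each call site narrows back down by checking `type`.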