OpenAI assistant fixes (#4969)

Jack Gerrits 2025-01-09 15:06:01 -05:00 committed by GitHub
parent 5b841e26d6
commit 0122d44aa3
2 changed files with 34 additions and 45 deletions

@@ -78,17 +78,16 @@ dev = [
 [tool.ruff]
 extend = "../../pyproject.toml"
-exclude = ["build", "dist", "src/autogen_core/application/protos", "tests/protos", "samples/protos"]
-include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
+exclude = ["build", "dist", "src/autogen_core/application/protos", "tests/protos"]
+include = ["src/**", "docs/**/*.ipynb", "tests/**"]
 [tool.ruff.lint.per-file-ignores]
-"samples/**.py" = ["T20"]
 "docs/**.ipynb" = ["T20"]
 [tool.pyright]
 extends = "../../pyproject.toml"
-include = ["src", "tests", "samples"]
-exclude = ["src/autogen_core/application/protos", "tests/protos", "samples/protos"]
+include = ["src", "tests"]
+exclude = ["src/autogen_core/application/protos", "tests/protos"]
 reportDeprecated = true
 [tool.pytest.ini_options]

@@ -3,7 +3,6 @@ import json
 import logging
 import os
 from typing import (
-    TYPE_CHECKING,
     Any,
     AsyncGenerator,
     Awaitable,
@@ -19,6 +18,7 @@ from typing import (
     cast,
 )
+import aiofiles
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import BaseChatAgent
 from autogen_agentchat.base import Response
@@ -33,50 +33,31 @@ from autogen_agentchat.messages import (
     ToolCallRequestEvent,
 )
 from autogen_core import CancellationToken, FunctionCall
+from autogen_core.models._model_client import ChatCompletionClient
 from autogen_core.models._types import FunctionExecutionResult
 from autogen_core.tools import FunctionTool, Tool
-_has_openai_dependencies: bool = True
-try:
-    import aiofiles
-    from openai import NOT_GIVEN
-    from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
-    from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
-    from openai.types.beta.file_search_tool_param import FileSearchToolParam
-    from openai.types.beta.function_tool_param import FunctionToolParam
-    from openai.types.shared_params.function_definition import FunctionDefinition
-except ImportError:
-    _has_openai_dependencies = False
-if TYPE_CHECKING:
-    import aiofiles
-    from openai import NOT_GIVEN, AsyncClient, NotGiven
-    from openai.pagination import AsyncCursorPage
-    from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
-    from openai.types import FileObject
-    from openai.types.beta import thread_update_params
-    from openai.types.beta.assistant import Assistant
-    from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
-    from openai.types.beta.assistant_tool_param import AssistantToolParam
-    from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
-    from openai.types.beta.file_search_tool_param import FileSearchToolParam
-    from openai.types.beta.function_tool_param import FunctionToolParam
-    from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
-    from openai.types.beta.threads import Message, MessageDeleted, Run
-    from openai.types.beta.vector_store import VectorStore
-    from openai.types.shared_params.function_definition import FunctionDefinition
+from openai import NOT_GIVEN, AsyncClient, NotGiven
+from openai.pagination import AsyncCursorPage
+from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
+from openai.types import FileObject
+from openai.types.beta import thread_update_params
+from openai.types.beta.assistant import Assistant
+from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
+from openai.types.beta.assistant_tool_param import AssistantToolParam
+from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
+from openai.types.beta.file_search_tool_param import FileSearchToolParam
+from openai.types.beta.function_tool_param import FunctionToolParam
+from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
+from openai.types.beta.threads import Message, MessageDeleted, Run
+from openai.types.beta.vector_store import VectorStore
+from openai.types.shared_params.function_definition import FunctionDefinition
 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
 def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
     """Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
-    if not _has_openai_dependencies:
-        raise RuntimeError(
-            "Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
-        )
     schema = tool.schema
     parameters: Dict[str, object] = {}
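With the openai dependency now imported unconditionally, the runtime guard in `_convert_tool_to_function_param` is gone and the function reduces to building an Assistants function tool parameter from the tool's schema. The rest of that function is not shown in this diff; the following is a minimal, self-contained sketch of that kind of conversion, assuming a plain dict schema with `name`, `description`, and a JSON-schema `parameters` entry (the helper name and schema shape are illustrative, not the package's exact internals):

```python
from typing import Any, Dict


def convert_schema_to_function_param(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch: turn a tool schema into an OpenAI Assistants function tool param."""
    parameters: Dict[str, object] = {}
    if "parameters" in schema:
        parameters = {
            "type": schema["parameters"].get("type", "object"),
            "properties": schema["parameters"].get("properties", {}),
            "required": schema["parameters"].get("required", []),
        }
    # FunctionToolParam shape: {"type": "function", "function": FunctionDefinition}
    return {
        "type": "function",
        "function": {
            "name": schema["name"],
            "description": schema.get("description", ""),
            "parameters": parameters,
        },
    }


if __name__ == "__main__":
    demo_schema = {
        "name": "get_weather",
        "description": "Look up the weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }
    print(convert_schema_to_function_param(demo_schema))
```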
@@ -158,10 +139,12 @@ class OpenAIAssistantAgent(BaseChatAgent):
             await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
             # Get response from the assistant
-            _response = await assistant.on_messages(
+            response = await assistant.on_messages(
                 [TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
             )
+            print(response)
             # Clean up resources
             await assistant.delete_uploaded_files(cancellation_token)
             await assistant.delete_assistant(cancellation_token)
@@ -207,9 +190,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
         tool_resources: Optional["ToolResources"] = None,
         top_p: Optional[float] = None,
     ) -> None:
-        if not _has_openai_dependencies:
-            raise RuntimeError(
-                "Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
+        if isinstance(client, ChatCompletionClient):
+            raise ValueError(
+                "Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
             )
         super().__init__(name, description)
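The constructor now rejects an AutoGen `ChatCompletionClient` up front, since this agent talks to the Assistants API through an `openai.AsyncClient`. A hedged construction sketch of what the guard expects; the keyword arguments mirror the parameters visible in this diff, but the model name, instructions, and tool list are illustrative and the import path may differ across versions:

```python
# Requires the `openai` package and an OPENAI_API_KEY in the environment.
from openai import AsyncClient

from autogen_ext.agents.openai import OpenAIAssistantAgent

client = AsyncClient()  # an OpenAI async client, not an AutoGen ChatCompletionClient
assistant = OpenAIAssistantAgent(
    name="analyst",
    description="Analyzes uploaded data files.",
    client=client,
    model="gpt-4o",  # illustrative model name
    instructions="You are a helpful data analyst.",
    tools=["code_interpreter"],
)
```

Passing a model client such as `OpenAIChatCompletionClient` here now fails fast with the `ValueError` above instead of breaking later inside the Assistants API calls.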
@@ -510,6 +493,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
     async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
         """Upload files and return their IDs."""
+        await self._ensure_initialized()
+
         if isinstance(file_paths, str):
             file_paths = [file_paths]
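`_upload_files` now ensures the agent is initialized before touching the client and accepts either a single path or an iterable of paths. The body of the upload loop is not shown in this diff; a rough standalone sketch under those assumptions, using aiofiles to read each file and the OpenAI files API, with every request linked to the cancellation token (the function name is illustrative):

```python
import asyncio
from typing import Iterable, List

import aiofiles
from autogen_core import CancellationToken
from openai import AsyncClient


async def upload_files(
    client: AsyncClient, file_paths: str | Iterable[str], cancellation_token: CancellationToken
) -> List[str]:
    """Illustrative upload loop: read each file and register it for the Assistants API."""
    if isinstance(file_paths, str):
        file_paths = [file_paths]
    file_ids: List[str] = []
    for path in file_paths:
        # Read the file asynchronously so the event loop is not blocked.
        async with aiofiles.open(path, mode="rb") as f:
            content = await f.read()
        # Link the upload future to the cancellation token so it can be cancelled.
        file = await cancellation_token.link_future(
            asyncio.ensure_future(client.files.create(file=(path, content), purpose="assistants"))
        )
        file_ids.append(file.id)
    return file_ids
```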
@@ -531,6 +516,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
         self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
     ) -> None:
         """Handle file uploads for the code interpreter."""
+        await self._ensure_initialized()
+
         file_ids = await self._upload_files(file_paths, cancellation_token)
         # Update thread with the new files
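After uploading, the agent attaches the new file IDs to the thread's code interpreter tool resources. The exact update logic is truncated in this diff; the sketch below shows one plausible shape of that call against the OpenAI beta Threads API, and the merge-with-existing-IDs step is an assumption rather than confirmed behavior:

```python
import asyncio
from typing import List

from autogen_core import CancellationToken
from openai import AsyncClient


async def attach_files_to_code_interpreter(
    client: AsyncClient, thread_id: str, file_ids: List[str], cancellation_token: CancellationToken
) -> None:
    """Illustrative: merge new file IDs into the thread's code_interpreter tool resources."""
    thread = await cancellation_token.link_future(
        asyncio.ensure_future(client.beta.threads.retrieve(thread_id=thread_id))
    )
    existing: List[str] = []
    if thread.tool_resources and thread.tool_resources.code_interpreter:
        existing = list(thread.tool_resources.code_interpreter.file_ids or [])
    await cancellation_token.link_future(
        asyncio.ensure_future(
            client.beta.threads.update(
                thread_id=thread_id,
                tool_resources={"code_interpreter": {"file_ids": existing + file_ids}},
            )
        )
    )
```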
@@ -596,6 +583,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
     async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
         """Delete all files that were uploaded by this agent instance."""
+        await self._ensure_initialized()
         for file_id in self._uploaded_file_ids:
             try:
                 await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
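Several methods in this file now call `await self._ensure_initialized()` before using the client. The initializer itself is not part of this diff; a minimal sketch of the lazy-initialization guard such a call implies, assuming the agent creates its assistant and thread on first use (class, attribute, and parameter names here are assumptions, not the actual implementation):

```python
import asyncio

from openai import AsyncClient


class LazyAssistant:
    """Illustrative lazy-initialization guard, not the actual OpenAIAssistantAgent code."""

    def __init__(self, client: AsyncClient, model: str, instructions: str) -> None:
        self._client = client
        self._model = model
        self._instructions = instructions
        self._assistant = None  # created on first use
        self._thread = None
        self._init_lock = asyncio.Lock()

    async def _ensure_initialized(self) -> None:
        # Create the assistant and thread exactly once, even under concurrent calls.
        async with self._init_lock:
            if self._assistant is None:
                self._assistant = await self._client.beta.assistants.create(
                    model=self._model, instructions=self._instructions
                )
            if self._thread is None:
                self._thread = await self._client.beta.threads.create()
```

Calling the guard at the top of every public method (uploads, message handling, cleanup) keeps those methods safe to invoke in any order.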
@@ -605,6 +593,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
     async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
         """Delete the assistant if it was created by this instance."""
+        await self._ensure_initialized()
         if self._assistant is not None and not self._assistant_id:
             try:
                 await cancellation_token.link_future(
@@ -616,6 +605,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
     async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
         """Delete the vector store if it was created by this instance."""
+        await self._ensure_initialized()
         if self._vector_store_id is not None:
             try:
                 await cancellation_token.link_future(
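The cleanup hunks above are truncated but all follow the same cancellable-deletion pattern: wrap the OpenAI call in a future and link it to the cancellation token. A sketch of how the vector store cleanup might look end to end; the attribute handling is an assumption, and the vector store API is accessed under `client.beta` as in the openai SDK versions this commit appears to target:

```python
import asyncio
from typing import Optional

from autogen_core import CancellationToken
from openai import AsyncClient


async def delete_vector_store_if_owned(
    client: AsyncClient, vector_store_id: Optional[str], cancellation_token: CancellationToken
) -> Optional[str]:
    """Illustrative: delete a vector store this instance created, tolerating failures."""
    if vector_store_id is None:
        return None
    try:
        # Linking the future lets the caller cancel the deletion via the token.
        await cancellation_token.link_future(
            asyncio.ensure_future(client.beta.vector_stores.delete(vector_store_id=vector_store_id))
        )
        return None  # deleted; caller can clear its stored ID
    except Exception as exc:  # deletion is best-effort cleanup
        print(f"Failed to delete vector store {vector_store_id}: {exc}")
        return vector_store_id
```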