mirror of https://github.com/microsoft/autogen.git
OpenAI assistant fixes (#4969)
This commit is contained in:
parent
5b841e26d6
commit
0122d44aa3
|
@ -78,17 +78,16 @@ dev = [
|
|||
|
||||
[tool.ruff]
|
||||
extend = "../../pyproject.toml"
|
||||
exclude = ["build", "dist", "src/autogen_core/application/protos", "tests/protos", "samples/protos"]
|
||||
include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
|
||||
exclude = ["build", "dist", "src/autogen_core/application/protos", "tests/protos"]
|
||||
include = ["src/**", "docs/**/*.ipynb", "tests/**"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"samples/**.py" = ["T20"]
|
||||
"docs/**.ipynb" = ["T20"]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["src", "tests", "samples"]
|
||||
exclude = ["src/autogen_core/application/protos", "tests/protos", "samples/protos"]
|
||||
include = ["src", "tests"]
|
||||
exclude = ["src/autogen_core/application/protos", "tests/protos"]
|
||||
reportDeprecated = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
|
|
@ -3,7 +3,6 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
|
@ -19,6 +18,7 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
import aiofiles
|
||||
from autogen_agentchat import EVENT_LOGGER_NAME
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
from autogen_agentchat.base import Response
|
||||
|
@ -33,50 +33,31 @@ from autogen_agentchat.messages import (
|
|||
ToolCallRequestEvent,
|
||||
)
|
||||
from autogen_core import CancellationToken, FunctionCall
|
||||
from autogen_core.models._model_client import ChatCompletionClient
|
||||
from autogen_core.models._types import FunctionExecutionResult
|
||||
from autogen_core.tools import FunctionTool, Tool
|
||||
|
||||
_has_openai_dependencies: bool = True
|
||||
try:
|
||||
import aiofiles
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
except ImportError:
|
||||
_has_openai_dependencies = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiofiles
|
||||
|
||||
from openai import NOT_GIVEN, AsyncClient, NotGiven
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types import FileObject
|
||||
from openai.types.beta import thread_update_params
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
|
||||
from openai.types.beta.threads import Message, MessageDeleted, Run
|
||||
from openai.types.beta.vector_store import VectorStore
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
from openai import NOT_GIVEN, AsyncClient, NotGiven
|
||||
from openai.pagination import AsyncCursorPage
|
||||
from openai.resources.beta.threads import AsyncMessages, AsyncRuns, AsyncThreads
|
||||
from openai.types import FileObject
|
||||
from openai.types.beta import thread_update_params
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from openai.types.beta.assistant_response_format_option_param import AssistantResponseFormatOptionParam
|
||||
from openai.types.beta.assistant_tool_param import AssistantToolParam
|
||||
from openai.types.beta.code_interpreter_tool_param import CodeInterpreterToolParam
|
||||
from openai.types.beta.file_search_tool_param import FileSearchToolParam
|
||||
from openai.types.beta.function_tool_param import FunctionToolParam
|
||||
from openai.types.beta.thread import Thread, ToolResources, ToolResourcesCodeInterpreter
|
||||
from openai.types.beta.threads import Message, MessageDeleted, Run
|
||||
from openai.types.beta.vector_store import VectorStore
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
def _convert_tool_to_function_param(tool: Tool) -> "FunctionToolParam":
|
||||
"""Convert an autogen Tool to an OpenAI Assistant function tool parameter."""
|
||||
if not _has_openai_dependencies:
|
||||
raise RuntimeError(
|
||||
"Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
|
||||
)
|
||||
|
||||
schema = tool.schema
|
||||
parameters: Dict[str, object] = {}
|
||||
|
@ -158,10 +139,12 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
await assistant.on_upload_for_code_interpreter("data.csv", cancellation_token)
|
||||
|
||||
# Get response from the assistant
|
||||
_response = await assistant.on_messages(
|
||||
response = await assistant.on_messages(
|
||||
[TextMessage(source="user", content="Analyze the data in data.csv")], cancellation_token
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
# Clean up resources
|
||||
await assistant.delete_uploaded_files(cancellation_token)
|
||||
await assistant.delete_assistant(cancellation_token)
|
||||
|
@ -207,9 +190,9 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
tool_resources: Optional["ToolResources"] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> None:
|
||||
if not _has_openai_dependencies:
|
||||
raise RuntimeError(
|
||||
"Missing dependecies for OpenAIAssistantAgent. Please ensure the autogen-ext package was installed with the 'openai' extra."
|
||||
if isinstance(client, ChatCompletionClient):
|
||||
raise ValueError(
|
||||
"Incorrect client passed to OpenAIAssistantAgent. Please use an OpenAI AsyncClient instance instead of an AutoGen ChatCompletionClient instance."
|
||||
)
|
||||
|
||||
super().__init__(name, description)
|
||||
|
@ -510,6 +493,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
|
||||
async def _upload_files(self, file_paths: str | Iterable[str], cancellation_token: CancellationToken) -> List[str]:
|
||||
"""Upload files and return their IDs."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
if isinstance(file_paths, str):
|
||||
file_paths = [file_paths]
|
||||
|
||||
|
@ -531,6 +516,8 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
self, file_paths: str | Iterable[str], cancellation_token: CancellationToken
|
||||
) -> None:
|
||||
"""Handle file uploads for the code interpreter."""
|
||||
await self._ensure_initialized()
|
||||
|
||||
file_ids = await self._upload_files(file_paths, cancellation_token)
|
||||
|
||||
# Update thread with the new files
|
||||
|
@ -596,6 +583,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
|
||||
async def delete_uploaded_files(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete all files that were uploaded by this agent instance."""
|
||||
await self._ensure_initialized()
|
||||
for file_id in self._uploaded_file_ids:
|
||||
try:
|
||||
await cancellation_token.link_future(asyncio.ensure_future(self._client.files.delete(file_id=file_id)))
|
||||
|
@ -605,6 +593,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
|
||||
async def delete_assistant(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the assistant if it was created by this instance."""
|
||||
await self._ensure_initialized()
|
||||
if self._assistant is not None and not self._assistant_id:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
|
@ -616,6 +605,7 @@ class OpenAIAssistantAgent(BaseChatAgent):
|
|||
|
||||
async def delete_vector_store(self, cancellation_token: CancellationToken) -> None:
|
||||
"""Delete the vector store if it was created by this instance."""
|
||||
await self._ensure_initialized()
|
||||
if self._vector_store_id is not None:
|
||||
try:
|
||||
await cancellation_token.link_future(
|
||||
|
|
Loading…
Reference in New Issue