mirror of https://github.com/microsoft/autogen.git
Bugfix/azure ai search embedding (#6248)
## Why are these changes needed?

Bug fix: add a `get_embedding()` implementation.

## Related issue number

Closes #6240

## Checks

- [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [X] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
b3f59057fa
commit
cc806a57ef
|
@ -166,7 +166,7 @@ def lang_to_cmd(lang: str) -> str:
|
|||
elif shutil.which("powershell") is not None:
|
||||
return "powershell"
|
||||
else:
|
||||
raise ValueError(f"Powershell or pwsh is not installed. Please install one of them.")
|
||||
raise ValueError("Powershell or pwsh is not installed. Please install one of them.")
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {lang}")
|
||||
|
||||
|
|
|
@ -1009,6 +1009,86 @@ class AzureAISearchTool(BaseAzureAISearchTool):
|
|||
finally:
|
||||
_allow_private_constructor.reset(token)
|
||||
|
||||
async def _get_embedding(self, query: str) -> List[float]:
|
||||
"""Generate embedding vector for the query text.
|
||||
|
||||
This method handles generating embeddings for vector search functionality.
|
||||
The embedding provider and model should be specified in the tool configuration.
|
||||
|
||||
Args:
|
||||
query (str): The text to generate embeddings for.
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector as a list of floats.
|
||||
|
||||
Raises:
|
||||
ValueError: If the embedding configuration is missing or invalid.
|
||||
"""
|
||||
embedding_provider = getattr(self.search_config, "embedding_provider", None)
|
||||
embedding_model = getattr(self.search_config, "embedding_model", None)
|
||||
|
||||
if not embedding_provider or not embedding_model:
|
||||
raise ValueError(
|
||||
"To use vector search, you must provide embedding_provider and embedding_model in the configuration."
|
||||
) from None
|
||||
|
||||
if embedding_provider.lower() == "azure_openai":
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from openai import AsyncAzureOpenAI
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Azure OpenAI SDK is required for embedding generation. "
|
||||
"Please install it with: uv add openai azure-identity"
|
||||
) from None
|
||||
|
||||
api_key = None
|
||||
if hasattr(self.search_config, "openai_api_key"):
|
||||
api_key = self.search_config.openai_api_key
|
||||
|
||||
api_version = getattr(self.search_config, "openai_api_version", "2023-05-15")
|
||||
endpoint = getattr(self.search_config, "openai_endpoint", None)
|
||||
|
||||
if not endpoint:
|
||||
raise ValueError("OpenAI endpoint must be provided for Azure OpenAI embeddings") from None
|
||||
|
||||
if api_key:
|
||||
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
|
||||
else:
|
||||
|
||||
def get_token() -> str:
|
||||
credential = DefaultAzureCredential()
|
||||
return credential.get_token("https://cognitiveservices.azure.com/.default").token
|
||||
|
||||
azure_client = AsyncAzureOpenAI(
|
||||
azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
|
||||
)
|
||||
|
||||
response = await azure_client.embeddings.create(model=embedding_model, input=query)
|
||||
return response.data[0].embedding
|
||||
|
||||
elif embedding_provider.lower() == "openai":
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OpenAI SDK is required for embedding generation. " "Please install it with: uv add openai"
|
||||
) from None
|
||||
|
||||
api_key = None
|
||||
if hasattr(self.search_config, "openai_api_key"):
|
||||
api_key = self.search_config.openai_api_key
|
||||
|
||||
openai_client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
response = await openai_client.embeddings.create(model=embedding_model, input=query)
|
||||
return response.data[0].embedding
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding provider: {embedding_provider}. "
|
||||
"Currently supported providers are 'azure_openai' and 'openai'."
|
||||
) from None
|
||||
|
||||
@classmethod
|
||||
def create_hybrid_search(
|
||||
cls,
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
"""Tests for the Azure AI Search tool."""
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Union, cast
|
||||
from unittest.mock import AsyncMock, patch
|
||||
# pyright: reportPrivateUsage=false
|
||||
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken
|
||||
|
@ -11,7 +13,7 @@ from autogen_ext.tools.azure._ai_search import (
|
|||
SearchQuery,
|
||||
SearchResult,
|
||||
SearchResults,
|
||||
_allow_private_constructor, # pyright: ignore[reportPrivateUsage]
|
||||
_allow_private_constructor,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential, TokenCredential
|
||||
from azure.core.exceptions import HttpResponseError
|
||||
|
@ -40,7 +42,7 @@ async def search_tool() -> AsyncGenerator[AzureAISearchTool, None]:
|
|||
async def _get_embedding(self, query: str) -> List[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
token = _allow_private_constructor.set(True) # pyright: ignore[reportPrivateUsage]
|
||||
token = _allow_private_constructor.set(True)
|
||||
try:
|
||||
tool = ConcreteSearchTool(
|
||||
name="test-search",
|
||||
|
@ -54,7 +56,7 @@ async def search_tool() -> AsyncGenerator[AzureAISearchTool, None]:
|
|||
)
|
||||
yield tool
|
||||
finally:
|
||||
_allow_private_constructor.reset(token) # pyright: ignore[reportPrivateUsage]
|
||||
_allow_private_constructor.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -106,7 +108,7 @@ async def test_search_tool_vector_search() -> None:
|
|||
async def _get_embedding(self, query: str) -> List[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
token = _allow_private_constructor.set(True) # pyright: ignore[reportPrivateUsage]
|
||||
token = _allow_private_constructor.set(True)
|
||||
try:
|
||||
tool = ConcreteSearchTool(
|
||||
name="vector-search",
|
||||
|
@ -131,7 +133,7 @@ async def test_search_tool_vector_search() -> None:
|
|||
assert results.results[0].content["title"] == "Vector Doc"
|
||||
assert results.results[0].score == 0.95
|
||||
finally:
|
||||
_allow_private_constructor.reset(token) # pyright: ignore[reportPrivateUsage]
|
||||
_allow_private_constructor.reset(token)
|
||||
|
||||
|
||||
class ConcreteAzureAISearchTool(AzureAISearchTool):
|
||||
|
@ -777,8 +779,6 @@ async def test_search_with_user_provided_vectors() -> None:
|
|||
assert results.results[0].content["title"] == "Vector Result"
|
||||
|
||||
mock_client.search.assert_called_once()
|
||||
_, kwargs = mock_client.search.call_args
|
||||
assert "vector_queries" in kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -1078,3 +1078,297 @@ async def test_search_with_different_query_types() -> None:
|
|||
|
||||
await tool.run(SearchQuery(query="object query"))
|
||||
mock_client.search.assert_called_once()
|
||||
|
||||
|
||||
class MockEmbeddingData:
    """Stand-in for a single OpenAI embedding datum (exposes ``.embedding``)."""

    def __init__(self, embedding: List[float]):
        self.embedding: List[float] = embedding
|
||||
|
||||
|
||||
class MockEmbeddingResponse:
    """Stand-in for an OpenAI embedding response (exposes ``.data``)."""

    def __init__(self, data: List["MockEmbeddingData"]):
        self.data: List["MockEmbeddingData"] = data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_embedding_methods() -> None:
    """Test the _get_embedding method with different providers."""

    class StubEmbeddingTool(AzureAISearchTool):
        async def _get_embedding(self, query: str) -> List[float]:
            return [0.1, 0.2, 0.3]

    with patch.object(AzureAISearchTool, "_get_embedding", autospec=True) as embed_mock:
        embed_mock.return_value = [0.1, 0.2, 0.3]

        search_tool = StubEmbeddingTool.create_vector_search(
            name="test_vector_search",
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            vector_fields=["embedding"],
        )

        # Invoke the patched base implementation directly and confirm the stubbed value.
        embedding = await AzureAISearchTool._get_embedding(search_tool, "test query")  # pyright: ignore[reportPrivateUsage]
        assert embedding == [0.1, 0.2, 0.3]
        embed_mock.assert_called_once_with(search_tool, "test query")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_embedding_azure_openai_path() -> None:
    """Test the Azure OpenAI path in _get_embedding."""
    azure_client_mock = AsyncMock()
    azure_client_mock.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])

    with (
        patch("openai.AsyncAzureOpenAI", return_value=azure_client_mock),
        patch("azure.identity.DefaultAzureCredential"),
        patch("autogen_ext.tools.azure._ai_search.getattr") as getattr_mock,
    ):
        # Table-driven replacement for the module-level getattr calls on the config.
        config_values = {
            "embedding_provider": "azure_openai",
            "embedding_model": "text-embedding-ada-002",
            "openai_endpoint": "https://test.openai.azure.com",
            "openai_api_key": "test-key",
        }

        def fake_getattr(obj: Any, name: str, default: Any = None) -> Any:
            return config_values.get(name, default)

        getattr_mock.side_effect = fake_getattr

        class PassthroughTool(AzureAISearchTool):
            async def _get_embedding(self, query: str) -> List[float]:
                return await AzureAISearchTool._get_embedding(self, query)

        reset_token = _allow_private_constructor.set(True)
        try:
            tool = PassthroughTool(
                name="test",
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                query_type="vector",
                vector_fields=["embedding"],
            )

            embedding = await tool._get_embedding("test query")  # pyright: ignore[reportPrivateUsage]
            assert embedding == [0.1, 0.2, 0.3]
            azure_client_mock.embeddings.create.assert_called_once_with(
                model="text-embedding-ada-002", input="test query"
            )
        finally:
            _allow_private_constructor.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_embedding_openai_path() -> None:
    """Test the OpenAI path in _get_embedding."""
    openai_client_mock = AsyncMock()
    openai_client_mock.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.4, 0.5, 0.6])])

    with (
        patch("openai.AsyncOpenAI", return_value=openai_client_mock),
        patch("autogen_ext.tools.azure._ai_search.getattr") as getattr_mock,
    ):
        # Table-driven replacement for the module-level getattr calls on the config.
        config_values = {
            "embedding_provider": "openai",
            "embedding_model": "text-embedding-3-small",
            "openai_api_key": "test-key",
        }

        def fake_getattr(obj: Any, name: str, default: Any = None) -> Any:
            return config_values.get(name, default)

        getattr_mock.side_effect = fake_getattr

        class PassthroughTool(AzureAISearchTool):
            async def _get_embedding(self, query: str) -> List[float]:
                return await AzureAISearchTool._get_embedding(self, query)

        reset_token = _allow_private_constructor.set(True)
        try:
            tool = PassthroughTool(
                name="test",
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                query_type="vector",
                vector_fields=["embedding"],
            )

            embedding = await tool._get_embedding("test query")  # pyright: ignore[reportPrivateUsage]
            assert embedding == [0.4, 0.5, 0.6]
            openai_client_mock.embeddings.create.assert_called_once_with(
                model="text-embedding-3-small", input="test query"
            )
        finally:
            _allow_private_constructor.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_embedding_error_cases_direct() -> None:
    """Test error cases in the _get_embedding method."""

    class DirectEmbeddingTool(AzureAISearchTool):
        async def _get_embedding(self, query: str) -> List[float]:
            return await super()._get_embedding(query)

    reset_token = _allow_private_constructor.set(True)
    try:
        tool = DirectEmbeddingTool(
            name="error_embedding_test",
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            query_type="vector",
            vector_fields=["embedding"],
        )

        # Neither provider nor model configured.
        with pytest.raises(
            ValueError, match="To use vector search, you must provide embedding_provider and embedding_model"
        ):
            await tool._get_embedding("test query")

        # Provider alone is still not enough.
        tool.search_config.embedding_provider = "azure_openai"
        with pytest.raises(
            ValueError, match="To use vector search, you must provide embedding_provider and embedding_model"
        ):
            await tool._get_embedding("test query")

        tool.search_config.embedding_model = "text-embedding-ada-002"

        # Azure path with a missing endpoint must fail.
        def _getattr_without_endpoint(obj: Any, name: str, default: Any = None) -> Any:
            return None if name == "openai_endpoint" else getattr(obj, name, default)

        with (
            patch("autogen_ext.tools.azure._ai_search.getattr", side_effect=_getattr_without_endpoint),
            pytest.raises(ValueError, match="OpenAI endpoint must be provided"),
        ):
            await tool._get_embedding("test query")

        # Unknown provider must be rejected even with a valid endpoint.
        tool.search_config.embedding_provider = "unsupported_provider"

        def _getattr_with_endpoint(obj: Any, name: str, default: Any = None) -> Any:
            return "https://test.openai.azure.com" if name == "openai_endpoint" else getattr(obj, name, default)

        with (
            patch("autogen_ext.tools.azure._ai_search.getattr", side_effect=_getattr_with_endpoint),
            pytest.raises(ValueError, match="Unsupported embedding provider"),
        ):
            await tool._get_embedding("test query")
    finally:
        _allow_private_constructor.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_azure_openai_with_default_credential() -> None:
    """Test Azure OpenAI with DefaultAzureCredential."""

    azure_client_mock = AsyncMock()
    azure_client_mock.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])

    credential_mock = MagicMock()
    credential_mock.get_token.return_value = MagicMock(token="mock-token")

    with (
        patch("openai.AsyncAzureOpenAI") as azure_client_cls,
        patch("azure.identity.DefaultAzureCredential", return_value=credential_mock),
        patch("autogen_ext.tools.azure._ai_search.getattr") as getattr_mock,
    ):
        azure_client_cls.return_value = azure_client_mock

        # No "openai_api_key" entry: forces the AAD token-provider branch.
        config_values = {
            "embedding_provider": "azure_openai",
            "embedding_model": "text-embedding-ada-002",
            "openai_endpoint": "https://test.openai.azure.com",
            "openai_api_version": "2023-05-15",
        }

        def fake_getattr(obj: Any, name: str, default: Any = None) -> Any:
            return config_values.get(name, default)

        getattr_mock.side_effect = fake_getattr

        class PassthroughTool(AzureAISearchTool):
            async def _get_embedding(self, query: str) -> List[float]:
                return await AzureAISearchTool._get_embedding(self, query)

        reset_token = _allow_private_constructor.set(True)
        try:
            tool = PassthroughTool(
                name="test",
                endpoint="https://test.search.windows.net",
                index_name="test-index",
                credential=AzureKeyCredential("test-key"),
                query_type="vector",
                vector_fields=["embedding"],
            )

            captured_provider: Optional[Callable[[], str]] = None

            def record_client_kwargs(
                api_key: Optional[str] = None,
                azure_ad_token_provider: Optional[Callable[[], str]] = None,
                **kwargs: Any,
            ) -> AsyncMock:
                # Capture the token provider the tool hands to the client.
                nonlocal captured_provider
                if azure_ad_token_provider:
                    captured_provider = azure_ad_token_provider
                return azure_client_mock

            azure_client_cls.side_effect = record_client_kwargs

            embedding = await tool._get_embedding("test query")  # pyright: ignore[reportPrivateUsage]
            assert embedding == [0.1, 0.2, 0.3]

            # Exercise the captured callback and verify it hits DefaultAzureCredential.
            assert captured_provider is not None
            captured_provider()
            credential_mock.get_token.assert_called_once_with("https://cognitiveservices.azure.com/.default")

            azure_client_mock.embeddings.create.assert_called_once_with(
                model="text-embedding-ada-002", input="test query"
            )
        finally:
            _allow_private_constructor.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_schema_property() -> None:
    """Test the schema property correctly defines the JSON schema for the tool."""
    keyword_tool = ConcreteAzureAISearchTool.create_keyword_search(
        name="schema_test",
        endpoint="https://test.search.windows.net",
        index_name="test-index",
        credential=AzureKeyCredential("test-key"),
    )

    schema = keyword_tool.schema

    assert schema["name"] == "schema_test"
    assert "description" in schema

    # Parameters must describe an object with a required "query" property.
    params = schema.get("parameters", {})  # pyright: ignore
    assert params.get("type") == "object"  # pyright: ignore
    assert "query" in params.get("properties", {})  # pyright: ignore
    assert "query" in params.get("required", [])  # pyright: ignore

    assert schema.get("strict") is True  # pyright: ignore
|
||||
|
|
Loading…
Reference in New Issue