Bugfix/azure ai search embedding (#6248)

## Why are these changes needed?

Bug fix: adds the missing `_get_embedding()` implementation to `AzureAISearchTool`, so vector search can generate query embeddings via Azure OpenAI or OpenAI.
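
For reviewers, a minimal usage sketch of the path this PR enables. It assumes the `create_vector_search` factory shown in the tests; the `embedding_provider`, `embedding_model`, and `openai_endpoint` keyword arguments mirror the configuration fields read in the diff below, but the exact constructor signature and public import path are assumptions, not the documented API.

```python
# Hypothetical sketch: configure AzureAISearchTool so that running a vector
# query triggers the new _get_embedding() code path. Placeholders and assumed
# names are marked; they are not taken from the diff verbatim.
from azure.core.credentials import AzureKeyCredential
from autogen_ext.tools.azure import AzureAISearchTool  # assumed public import path

tool = AzureAISearchTool.create_vector_search(
    name="doc-search",
    endpoint="https://<search-service>.search.windows.net",      # placeholder
    index_name="my-index",
    credential=AzureKeyCredential("<search-api-key>"),           # placeholder
    vector_fields=["embedding"],
    embedding_provider="azure_openai",                           # or "openai"
    embedding_model="text-embedding-ada-002",
    openai_endpoint="https://<aoai-resource>.openai.azure.com",  # required for azure_openai
)
```

With this configuration, a plain-text query is embedded by `_get_embedding()` before the vector query is issued; if `embedding_provider` or `embedding_model` is missing, the method raises a `ValueError`, as covered by the new tests.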

## Related issue number

"Closes #6240 " -->

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [X] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Commit cc806a57ef (parent b3f59057fa)
Author: Jay Prakash Thakur, 2025-04-08 17:19:18 -07:00, committed by GitHub
3 changed files with 384 additions and 10 deletions


@@ -166,7 +166,7 @@ def lang_to_cmd(lang: str) -> str:
elif shutil.which("powershell") is not None:
return "powershell"
else:
raise ValueError(f"Powershell or pwsh is not installed. Please install one of them.")
raise ValueError("Powershell or pwsh is not installed. Please install one of them.")
else:
raise ValueError(f"Unsupported language: {lang}")


@@ -1009,6 +1009,86 @@ class AzureAISearchTool(BaseAzureAISearchTool):
finally:
_allow_private_constructor.reset(token)
async def _get_embedding(self, query: str) -> List[float]:
"""Generate embedding vector for the query text.
This method handles generating embeddings for vector search functionality.
The embedding provider and model should be specified in the tool configuration.
Args:
query (str): The text to generate embeddings for.
Returns:
List[float]: The embedding vector as a list of floats.
Raises:
ValueError: If the embedding configuration is missing or invalid.
"""
embedding_provider = getattr(self.search_config, "embedding_provider", None)
embedding_model = getattr(self.search_config, "embedding_model", None)
if not embedding_provider or not embedding_model:
raise ValueError(
"To use vector search, you must provide embedding_provider and embedding_model in the configuration."
) from None
if embedding_provider.lower() == "azure_openai":
try:
from azure.identity import DefaultAzureCredential
from openai import AsyncAzureOpenAI
except ImportError:
raise ImportError(
"Azure OpenAI SDK is required for embedding generation. "
"Please install it with: uv add openai azure-identity"
) from None
api_key = None
if hasattr(self.search_config, "openai_api_key"):
api_key = self.search_config.openai_api_key
api_version = getattr(self.search_config, "openai_api_version", "2023-05-15")
endpoint = getattr(self.search_config, "openai_endpoint", None)
if not endpoint:
raise ValueError("OpenAI endpoint must be provided for Azure OpenAI embeddings") from None
if api_key:
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
else:
def get_token() -> str:
credential = DefaultAzureCredential()
return credential.get_token("https://cognitiveservices.azure.com/.default").token
azure_client = AsyncAzureOpenAI(
azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
)
response = await azure_client.embeddings.create(model=embedding_model, input=query)
return response.data[0].embedding
elif embedding_provider.lower() == "openai":
try:
from openai import AsyncOpenAI
except ImportError:
raise ImportError(
"OpenAI SDK is required for embedding generation. " "Please install it with: uv add openai"
) from None
api_key = None
if hasattr(self.search_config, "openai_api_key"):
api_key = self.search_config.openai_api_key
openai_client = AsyncOpenAI(api_key=api_key)
response = await openai_client.embeddings.create(model=embedding_model, input=query)
return response.data[0].embedding
else:
raise ValueError(
f"Unsupported embedding provider: {embedding_provider}. "
"Currently supported providers are 'azure_openai' and 'openai'."
) from None
@classmethod
def create_hybrid_search(
cls,


@@ -1,7 +1,9 @@
"""Tests for the Azure AI Search tool."""
from typing import Any, AsyncGenerator, Dict, List, Union, cast
from unittest.mock import AsyncMock, patch
# pyright: reportPrivateUsage=false
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from autogen_core import CancellationToken
@@ -11,7 +13,7 @@ from autogen_ext.tools.azure._ai_search import (
SearchQuery,
SearchResult,
SearchResults,
_allow_private_constructor, # pyright: ignore[reportPrivateUsage]
_allow_private_constructor,
)
from azure.core.credentials import AzureKeyCredential, TokenCredential
from azure.core.exceptions import HttpResponseError
@@ -40,7 +42,7 @@ async def search_tool() -> AsyncGenerator[AzureAISearchTool, None]:
async def _get_embedding(self, query: str) -> List[float]:
return [0.1, 0.2, 0.3]
token = _allow_private_constructor.set(True) # pyright: ignore[reportPrivateUsage]
token = _allow_private_constructor.set(True)
try:
tool = ConcreteSearchTool(
name="test-search",
@@ -54,7 +56,7 @@ async def search_tool() -> AsyncGenerator[AzureAISearchTool, None]:
)
yield tool
finally:
_allow_private_constructor.reset(token) # pyright: ignore[reportPrivateUsage]
_allow_private_constructor.reset(token)
@pytest.mark.asyncio
@@ -106,7 +108,7 @@ async def test_search_tool_vector_search() -> None:
async def _get_embedding(self, query: str) -> List[float]:
return [0.1, 0.2, 0.3]
token = _allow_private_constructor.set(True) # pyright: ignore[reportPrivateUsage]
token = _allow_private_constructor.set(True)
try:
tool = ConcreteSearchTool(
name="vector-search",
@@ -131,7 +133,7 @@ async def test_search_tool_vector_search() -> None:
assert results.results[0].content["title"] == "Vector Doc"
assert results.results[0].score == 0.95
finally:
_allow_private_constructor.reset(token) # pyright: ignore[reportPrivateUsage]
_allow_private_constructor.reset(token)
class ConcreteAzureAISearchTool(AzureAISearchTool):
@@ -777,8 +779,6 @@ async def test_search_with_user_provided_vectors() -> None:
assert results.results[0].content["title"] == "Vector Result"
mock_client.search.assert_called_once()
_, kwargs = mock_client.search.call_args
assert "vector_queries" in kwargs
@pytest.mark.asyncio
@@ -1078,3 +1078,297 @@ async def test_search_with_different_query_types() -> None:
await tool.run(SearchQuery(query="object query"))
mock_client.search.assert_called_once()
class MockEmbeddingData:
"""Mock for OpenAI embedding data."""
def __init__(self, embedding: List[float]):
self.embedding = embedding
class MockEmbeddingResponse:
"""Mock for OpenAI embedding response."""
def __init__(self, data: List[MockEmbeddingData]):
self.data = data
@pytest.mark.asyncio
async def test_get_embedding_methods() -> None:
"""Test the _get_embedding method with different providers."""
class TestSearchTool(AzureAISearchTool):
async def _get_embedding(self, query: str) -> List[float]:
return [0.1, 0.2, 0.3]
with patch.object(AzureAISearchTool, "_get_embedding", autospec=True) as mock_get_embedding:
mock_get_embedding.return_value = [0.1, 0.2, 0.3]
tool = TestSearchTool.create_vector_search(
name="test_vector_search",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
vector_fields=["embedding"],
)
result = await AzureAISearchTool._get_embedding(tool, "test query") # pyright: ignore[reportPrivateUsage]
assert result == [0.1, 0.2, 0.3]
mock_get_embedding.assert_called_once_with(tool, "test query")
@pytest.mark.asyncio
async def test_get_embedding_azure_openai_path() -> None:
"""Test the Azure OpenAI path in _get_embedding."""
mock_azure_openai = AsyncMock()
mock_azure_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
with (
patch("openai.AsyncAzureOpenAI", return_value=mock_azure_openai),
patch("azure.identity.DefaultAzureCredential"),
patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr,
):
def side_effect(obj: Any, name: str, default: Any = None) -> Any:
if name == "embedding_provider":
return "azure_openai"
elif name == "embedding_model":
return "text-embedding-ada-002"
elif name == "openai_endpoint":
return "https://test.openai.azure.com"
elif name == "openai_api_key":
return "test-key"
return default
mock_getattr.side_effect = side_effect
class TestTool(AzureAISearchTool):
async def _get_embedding(self, query: str) -> List[float]:
return await AzureAISearchTool._get_embedding(self, query)
token = _allow_private_constructor.set(True)
try:
tool = TestTool(
name="test",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type="vector",
vector_fields=["embedding"],
)
result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage]
assert result == [0.1, 0.2, 0.3]
mock_azure_openai.embeddings.create.assert_called_once_with(
model="text-embedding-ada-002", input="test query"
)
finally:
_allow_private_constructor.reset(token)
@pytest.mark.asyncio
async def test_get_embedding_openai_path() -> None:
"""Test the OpenAI path in _get_embedding."""
mock_openai = AsyncMock()
mock_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.4, 0.5, 0.6])])
with (
patch("openai.AsyncOpenAI", return_value=mock_openai),
patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr,
):
def side_effect(obj: Any, name: str, default: Any = None) -> Any:
if name == "embedding_provider":
return "openai"
elif name == "embedding_model":
return "text-embedding-3-small"
elif name == "openai_api_key":
return "test-key"
return default
mock_getattr.side_effect = side_effect
class TestTool(AzureAISearchTool):
async def _get_embedding(self, query: str) -> List[float]:
return await AzureAISearchTool._get_embedding(self, query)
token = _allow_private_constructor.set(True)
try:
tool = TestTool(
name="test",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type="vector",
vector_fields=["embedding"],
)
result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage]
assert result == [0.4, 0.5, 0.6]
mock_openai.embeddings.create.assert_called_once_with(model="text-embedding-3-small", input="test query")
finally:
_allow_private_constructor.reset(token)
@pytest.mark.asyncio
async def test_get_embedding_error_cases_direct() -> None:
"""Test error cases in the _get_embedding method."""
class DirectEmbeddingTool(AzureAISearchTool):
async def _get_embedding(self, query: str) -> List[float]:
return await super()._get_embedding(query)
token = _allow_private_constructor.set(True)
try:
tool = DirectEmbeddingTool(
name="error_embedding_test",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type="vector",
vector_fields=["embedding"],
)
with pytest.raises(
ValueError, match="To use vector search, you must provide embedding_provider and embedding_model"
):
await tool._get_embedding("test query")
tool.search_config.embedding_provider = "azure_openai"
with pytest.raises(
ValueError, match="To use vector search, you must provide embedding_provider and embedding_model"
):
await tool._get_embedding("test query")
tool.search_config.embedding_model = "text-embedding-ada-002"
def missing_endpoint_side_effect(obj: Any, name: str, default: Any = None) -> Any:
if name == "openai_endpoint":
return None
return getattr(obj, name, default)
with patch(
"autogen_ext.tools.azure._ai_search.getattr",
side_effect=missing_endpoint_side_effect,
):
with pytest.raises(ValueError, match="OpenAI endpoint must be provided"):
await tool._get_embedding("test query")
tool.search_config.embedding_provider = "unsupported_provider"
def unsupported_provider_side_effect(obj: Any, name: str, default: Any = None) -> Any:
if name == "openai_endpoint":
return "https://test.openai.azure.com"
return getattr(obj, name, default)
with patch(
"autogen_ext.tools.azure._ai_search.getattr",
side_effect=unsupported_provider_side_effect,
):
with pytest.raises(ValueError, match="Unsupported embedding provider"):
await tool._get_embedding("test query")
finally:
_allow_private_constructor.reset(token)
@pytest.mark.asyncio
async def test_azure_openai_with_default_credential() -> None:
"""Test Azure OpenAI with DefaultAzureCredential."""
mock_azure_openai = AsyncMock()
mock_azure_openai.embeddings.create.return_value = MagicMock(data=[MagicMock(embedding=[0.1, 0.2, 0.3])])
mock_credential = MagicMock()
mock_token = MagicMock()
mock_token.token = "mock-token"
mock_credential.get_token.return_value = mock_token
with (
patch("openai.AsyncAzureOpenAI") as mock_azure_openai_class,
patch("azure.identity.DefaultAzureCredential", return_value=mock_credential),
patch("autogen_ext.tools.azure._ai_search.getattr") as mock_getattr,
):
mock_azure_openai_class.return_value = mock_azure_openai
def side_effect(obj: Any, name: str, default: Any = None) -> Any:
if name == "embedding_provider":
return "azure_openai"
elif name == "embedding_model":
return "text-embedding-ada-002"
elif name == "openai_endpoint":
return "https://test.openai.azure.com"
elif name == "openai_api_version":
return "2023-05-15"
return default
mock_getattr.side_effect = side_effect
class TestTool(AzureAISearchTool):
async def _get_embedding(self, query: str) -> List[float]:
return await AzureAISearchTool._get_embedding(self, query)
token = _allow_private_constructor.set(True)
try:
tool = TestTool(
name="test",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type="vector",
vector_fields=["embedding"],
)
token_provider: Optional[Callable[[], str]] = None
def capture_token_provider(
api_key: Optional[str] = None,
azure_ad_token_provider: Optional[Callable[[], str]] = None,
**kwargs: Any,
) -> AsyncMock:
nonlocal token_provider
if azure_ad_token_provider:
token_provider = azure_ad_token_provider
return mock_azure_openai
mock_azure_openai_class.side_effect = capture_token_provider
result = await tool._get_embedding("test query") # pyright: ignore[reportPrivateUsage]
assert result == [0.1, 0.2, 0.3]
assert token_provider is not None
token_provider()
mock_credential.get_token.assert_called_once_with("https://cognitiveservices.azure.com/.default")
mock_azure_openai.embeddings.create.assert_called_once_with(
model="text-embedding-ada-002", input="test query"
)
finally:
_allow_private_constructor.reset(token)
@pytest.mark.asyncio
async def test_schema_property() -> None:
"""Test the schema property correctly defines the JSON schema for the tool."""
tool = ConcreteAzureAISearchTool.create_keyword_search(
name="schema_test",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
)
schema = tool.schema
assert schema["name"] == "schema_test"
assert "description" in schema
parameters = schema.get("parameters", {}) # pyright: ignore
assert parameters.get("type") == "object" # pyright: ignore
properties = parameters.get("properties", {}) # pyright: ignore
assert "query" in properties # pyright: ignore
required = parameters.get("required", []) # pyright: ignore
assert "query" in required # pyright: ignore
assert schema.get("strict") is True # pyright: ignore