Add Azure AI Search tool implementation (#5844)

# Azure AI Search Tool Implementation

This PR adds a new tool for Azure AI Search integration to autogen-ext,
enabling agents to search and retrieve information from Azure AI Search
indexes.

## Why Are These Changes Needed?
AutoGen currently lacks native integration with Azure AI Search, which
is a powerful enterprise search service that supports semantic, vector,
and hybrid search capabilities. This integration enables agents to:
1. Retrieve relevant information from large document collections
2. Perform semantic search with AI-powered ranking
3. Execute vector similarity search using embeddings
4. Combine text and vector approaches for optimal results

This tool complements existing retrieval capabilities and provides a
seamless way to integrate with Azure's search infrastructure.
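
To make item 4 concrete: hybrid search merges a keyword ranking and a vector ranking into one result list, and Azure AI Search documents Reciprocal Rank Fusion (RRF) for that merge. The sketch below illustrates the idea in plain Python. The function name `reciprocal_rank_fusion`, the constant `k=60`, and the sample document ids are assumptions made for illustration; none of them come from this PR, and the service performs the fusion server-side.

```python
from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse several ranked lists of document ids into a single ranking."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            # Each list contributes 1 / (k + rank) for every document it contains.
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties are broken by id so the order is deterministic.
    return sorted(scores, key=lambda doc_id: (-scores[doc_id], doc_id))


keyword_hits = ["doc1", "doc2", "doc3"]  # hypothetical keyword ranking
vector_hits = ["doc3", "doc1", "doc4"]   # hypothetical vector ranking
fused = reciprocal_rank_fusion([keyword_hits, vector_hits])
print(fused)
```

Documents that rank well in both lists (`doc1`, `doc3`) end up ahead of documents that appear in only one.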

## Features
- **Multiple Search Types**: Support for text, semantic, vector, and
hybrid search
- **Flexible Configuration**: Customizable search parameters and fields
- **Robust Error Handling**: User-friendly error messages with
actionable guidance
- **Performance Optimizations**: Configurable caching and retry
mechanisms
- **Vector Search Support**: Built-in embedding generation with
extensibility

## Usage Example
```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.azure import AzureAISearchTool

# Create the search tool
search_tool = AzureAISearchTool.load_component({
    "provider": "autogen_ext.tools.azure.AzureAISearchTool",
    "config": {
        "name": "DocumentSearch",
        "description": "Search for information in the knowledge base",
        "endpoint": "https://your-service.search.windows.net",
        "index_name": "your-index",
        "credential": {"api_key": "your-api-key"},
        # AzureAISearchConfig in this commit accepts
        # "keyword", "fulltext", "vector", or "hybrid".
        "query_type": "keyword",
    }
})


async def main() -> None:
    # Create an agent with the search tool. The model client is an assumption
    # for this example; any chat completion client should work.
    assistant = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        tools=[search_tool],
    )
    # Start the conversation
    result = await assistant.run(
        task="What information do we have about quantum computing in our knowledge base?"
    )
    print(result.messages[-1].content)


asyncio.run(main())
```

## Testing
- Added unit tests for all search types (text, semantic, vector, hybrid)
- Added tests for error handling and cancellation
- All tests pass locally
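
The cancellation tests themselves are in a suppressed diff, so the snippet below is only a sketch of the behaviour such a test typically checks: an in-flight search is cancelled and the caller observes the cancellation. It uses plain `asyncio` task cancellation as a stand-in for AutoGen's cancellation token, and `slow_search` and `run_with_cancel` are hypothetical names.

```python
import asyncio
from typing import List


async def slow_search(query: str) -> List[str]:
    """Pretend to be a search call that takes a long time."""
    await asyncio.sleep(10)
    return [query]


async def run_with_cancel() -> str:
    task = asyncio.create_task(slow_search("quantum computing"))
    await asyncio.sleep(0)  # yield once so the task starts running
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        return "cancelled"
    return "completed"


outcome = asyncio.run(run_with_cancel())
print(outcome)
```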

## Documentation
- Added comprehensive docstrings with examples
- Included warnings about placeholder embedding implementation
- Added links to Azure AI Search documentation

## Related issue number

Closes #5419 

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Jay Prakash Thakur 2025-04-02 16:16:48 -07:00 committed by GitHub
parent d7f2b56846
commit 0d9b574d09
9 changed files with 2684 additions and 0 deletions

View File

@ -55,6 +55,7 @@ python/autogen_ext.models.anthropic
python/autogen_ext.models.semantic_kernel
python/autogen_ext.models.ollama
python/autogen_ext.models.llama_cpp
python/autogen_ext.tools.azure
python/autogen_ext.tools.code_execution
python/autogen_ext.tools.graphrag
python/autogen_ext.tools.http

View File

@ -0,0 +1,8 @@
autogen\_ext.tools.azure
========================

.. automodule:: autogen_ext.tools.azure
   :members:
   :undoc-members:
   :show-inheritance:

View File

@ -25,6 +25,7 @ azure = [
    "azure-ai-inference>=1.0.0b7",
    "azure-core",
    "azure-identity",
    "azure-search-documents>=11.4.0",
]
docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
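
Because `azure-search-documents` is installed only with the `azure` extra, code that depends on it may want to fail with an actionable message when the extra is missing. The following is a minimal sketch of such a guard using only the standard library. `is_installed` and `require_extra` are hypothetical helpers and are not part of this PR.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if module_name can be located without fully importing it."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # Raised when a parent package (for example "azure.search") is missing
        # or when an already-imported module has no usable spec.
        return False


def require_extra(module_name: str, extra: str) -> None:
    """Raise an ImportError that tells the user which extra to install."""
    if not is_installed(module_name):
        raise ImportError(
            f"{module_name} is not installed. "
            f'Install it with: pip install -U "autogen-ext[{extra}]"'
        )


# The standard library module "json" is always present, so this passes silently.
require_extra("json", "azure")
```

Checking the full dotted name matters here: `azure` is a namespace shared by several packages, so the presence of `azure.core` says nothing about `azure.search.documents`.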

View File

@ -0,0 +1,17 @@
from ._ai_search import (
    AzureAISearchTool,
    BaseAzureAISearchTool,
    SearchQuery,
    SearchResult,
    SearchResults,
)
from ._config import AzureAISearchConfig

__all__ = [
    "AzureAISearchTool",
    "BaseAzureAISearchTool",
    "SearchQuery",
    "SearchResult",
    "SearchResults",
    "AzureAISearchConfig",
]

File diff suppressed because it is too large

View File

@ -0,0 +1,177 @
"""Configuration for Azure AI Search tool.

This module provides configuration classes for the Azure AI Search tool, including
settings for authentication, search behavior, retry policies, and caching.
"""

import logging
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)

from azure.core.credentials import AzureKeyCredential, TokenCredential
from pydantic import BaseModel, Field, model_validator

# Add explicit ignore for the specific model validator error
# pyright: reportArgumentType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownVariableType=false

T = TypeVar("T", bound="AzureAISearchConfig")

logger = logging.getLogger(__name__)
class AzureAISearchConfig(BaseModel):
    """Configuration for Azure AI Search tool.

    This class defines the configuration parameters for :class:`AzureAISearchTool`.
    It provides options for customizing search behavior including query types,
    field selection, authentication, retry policies, and caching strategies.

    .. note::

        This class requires the :code:`azure` extra for the :code:`autogen-ext` package.

        .. code-block:: bash

            pip install -U "autogen-ext[azure]"

    Example:
        .. code-block:: python

            from azure.core.credentials import AzureKeyCredential
            from autogen_ext.tools.azure import AzureAISearchConfig

            config = AzureAISearchConfig(
                name="doc_search",
                endpoint="https://my-search.search.windows.net",
                index_name="my-index",
                credential=AzureKeyCredential("<your-key>"),
                query_type="vector",
                vector_fields=["embedding"],
            )

    For more details, see:

    * `Azure AI Search Overview <https://learn.microsoft.com/azure/search/search-what-is-azure-search>`_
    * `Vector Search <https://learn.microsoft.com/azure/search/vector-search-overview>`_

    Args:
        name (str): Name for the tool instance, used to identify it in the agent's toolkit.
        description (Optional[str]): Human-readable description of what this tool does and how to use it.
        endpoint (str): The full URL of your Azure AI Search service, in the format
            'https://<service-name>.search.windows.net'.
        index_name (str): Name of the target search index in your Azure AI Search service.
            The index must be pre-created and properly configured.
        api_version (str): Azure AI Search REST API version to use. Defaults to '2023-11-01'.
            Only change if you need specific features from a different API version.
        credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential:
            - AzureKeyCredential: For API key authentication (admin/query key)
            - TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential)
        query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): The search query mode to use:
            - 'keyword': Basic keyword search (default)
            - 'fulltext': Full Lucene query syntax
            - 'vector': Vector similarity search
            - 'hybrid': Hybrid search combining multiple techniques
        search_fields (Optional[List[str]]): List of index fields to search within. If not specified,
            searches all searchable fields. Example: ['title', 'content'].
        select_fields (Optional[List[str]]): Fields to return in search results. If not specified,
            returns all fields. Use to optimize response size.
        vector_fields (Optional[List[str]]): Vector field names for vector search. Must be configured
            in your search index as vector fields. Required for vector search.
        top (Optional[int]): Maximum number of documents to return in search results.
            Helps control response size and processing time.
        retry_enabled (bool): Whether to enable retry policy for transient errors. Defaults to True.
        retry_max_attempts (Optional[int]): Maximum number of retry attempts for failed requests. Defaults to 3.
        retry_mode (Literal["fixed", "exponential"]): Retry backoff strategy: fixed or exponential. Defaults to "exponential".
        enable_caching (bool): Whether to enable client-side caching of search results. Defaults to False.
        cache_ttl_seconds (int): Time-to-live for cached search results in seconds. Defaults to 300 (5 minutes).
        filter (Optional[str]): OData filter expression to refine search results.
        embedding_provider (Optional[str]): Name of the embedding provider to use
            (e.g., 'azure_openai', 'openai').
        embedding_model (Optional[str]): Model name to use for generating embeddings.
        embedding_dimension (Optional[int]): Dimension of the embedding vectors produced by the model.
    """

    name: str = Field(description="The name of the tool")
    description: Optional[str] = Field(default=None, description="A description of the tool")
    endpoint: str = Field(description="The endpoint URL for your Azure AI Search service")
    index_name: str = Field(description="The name of the search index to query")
    api_version: str = Field(default="2023-11-01", description="API version to use")
    credential: Union[AzureKeyCredential, TokenCredential] = Field(
        description="The credential to use for authentication"
    )
    query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = Field(
        default="keyword", description="Type of query to perform"
    )
    search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in")
    select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results")
    vector_fields: Optional[List[str]] = Field(
        default=None, description="Optional list of vector fields for vector search"
    )
    top: Optional[int] = Field(default=None, description="Optional number of results to return")
    filter: Optional[str] = Field(default=None, description="Optional OData filter expression to refine search results")

    retry_enabled: bool = Field(default=True, description="Whether to enable retry policy for transient errors")
    retry_max_attempts: Optional[int] = Field(
        default=3, description="Maximum number of retry attempts for failed requests"
    )
    retry_mode: Literal["fixed", "exponential"] = Field(
        default="exponential",
        description="Retry backoff strategy: fixed or exponential",
    )

    enable_caching: bool = Field(
        default=False,
        description="Whether to enable client-side caching of search results",
    )
    cache_ttl_seconds: int = Field(
        default=300,  # 5 minutes
        description="Time-to-live for cached search results in seconds",
    )

    embedding_provider: Optional[str] = Field(
        default=None,
        description="Name of embedding provider to use (e.g., 'azure_openai', 'openai')",
    )
    embedding_model: Optional[str] = Field(default=None, description="Model name to use for generating embeddings")
    embedding_dimension: Optional[int] = Field(
        default=None, description="Dimension of embedding vectors produced by the model"
    )

    model_config = {"arbitrary_types_allowed": True}

    # model_validator must be the outermost decorator so that pydantic registers the hook.
    @model_validator(mode="before")
    @classmethod
    def validate_credentials(cls: Type[T], data: Any) -> Any:
        """Validate and convert credential data."""
        if not isinstance(data, dict):
            return data

        result = {}

        for key, value in data.items():
            result[str(key)] = value

        if "credential" in result:
            credential = result["credential"]

            if isinstance(credential, dict) and "api_key" in credential:
                api_key = str(credential["api_key"])
                result["credential"] = AzureKeyCredential(api_key)

        return result

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Custom model_dump to handle credentials."""
        result: Dict[str, Any] = super().model_dump(**kwargs)

        if isinstance(self.credential, AzureKeyCredential):
            result["credential"] = {"type": "AzureKeyCredential"}
        elif isinstance(self.credential, TokenCredential):
            result["credential"] = {"type": "TokenCredential"}

        return result
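
The credential handling in `validate_credentials` and `model_dump` can be followed without pydantic or the Azure SDK. The sketch below mirrors the two steps with the standard library only: a plain `{"api_key": ...}` mapping is promoted to a credential object on the way in, and only the credential's type name is emitted on the way out. `KeyCredential`, `normalize_credential`, and `redact_credential` are hypothetical stand-ins, not the library's API.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class KeyCredential:
    """Hypothetical stand-in for azure.core.credentials.AzureKeyCredential."""

    key: str


def normalize_credential(data: Any) -> Any:
    """Mirror the 'before' validator: turn {'api_key': ...} into a credential."""
    if not isinstance(data, dict):
        return data
    result: Dict[str, Any] = {str(key): value for key, value in data.items()}
    credential = result.get("credential")
    if isinstance(credential, dict) and "api_key" in credential:
        result["credential"] = KeyCredential(str(credential["api_key"]))
    return result


def redact_credential(dumped: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror model_dump: keep the credential's type name, never the key."""
    redacted = dict(dumped)
    if isinstance(redacted.get("credential"), KeyCredential):
        redacted["credential"] = {"type": "AzureKeyCredential"}
    return redacted


raw = {"name": "doc_search", "credential": {"api_key": "secret"}}
normalized = normalize_credential(raw)
dumped = redact_credential(normalized)
print(dumped)
```

One consequence of this design is that a dumped configuration no longer contains the key, so it cannot be loaded again without supplying the credential separately.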

View File

@ -0,0 +1,303 @
"""Test fixtures for Azure AI Search tool tests."""

import warnings
from typing import Any, Dict, Generator, List, Protocol, Type, TypeVar, Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from autogen_core import ComponentModel

T = TypeVar("T")


class AccessTokenProtocol(Protocol):
    """Protocol matching Azure AccessToken."""

    token: str
    expires_on: int


class MockAccessToken:
    """Mock implementation of AccessToken."""

    def __init__(self, token: str, expires_on: int) -> None:
        self.token = token
        self.expires_on = expires_on


class MockAzureKeyCredential:
    """Mock implementation of AzureKeyCredential."""

    def __init__(self, key: str) -> None:
        self.key = key


class MockTokenCredential:
    """Mock implementation of TokenCredential for testing."""

    def get_token(
        self,
        *scopes: str,
        claims: str | None = None,
        tenant_id: str | None = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessTokenProtocol:
        """Mock get_token method that implements TokenCredential protocol."""
        return MockAccessToken("mock-token", 12345)


try:
    from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential

    _access_token_type: Type[AccessToken] = AccessToken
    azure_sdk_available = True
except ImportError:
    AzureKeyCredential = MockAzureKeyCredential  # type: ignore
    TokenCredential = MockTokenCredential  # type: ignore
    _access_token_type = MockAccessToken  # type: ignore
    azure_sdk_available = False

CredentialType = Union[AzureKeyCredential, TokenCredential, MockAzureKeyCredential, MockTokenCredential, Any]

needs_azure_sdk = pytest.mark.skipif(not azure_sdk_available, reason="Azure SDK not available")

warnings.filterwarnings(
    "ignore",
    message="Type google.*uses PyType_Spec with a metaclass that has custom tp_new",
    category=DeprecationWarning,
)


@pytest.fixture
def mock_vectorized_query() -> Generator[MagicMock, None, None]:
    """Create a mock VectorizedQuery for testing."""
    with patch("azure.search.documents.models.VectorizedQuery") as mock:
        yield mock


@pytest.fixture
def test_config() -> ComponentModel:
    """Return a test configuration for the Azure AI Search tool."""
    return ComponentModel(
        provider="autogen_ext.tools.azure.MockAzureAISearchTool",
        config={
            "name": "TestAzureSearch",
            "description": "Test Azure AI Search Tool",
            "endpoint": "https://test-search-service.search.windows.net",
            "index_name": "test-index",
            "api_version": "2023-10-01-Preview",
            "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
            "query_type": "keyword",
            "search_fields": ["content", "title"],
            "select_fields": ["id", "content", "title", "source"],
            "top": 5,
        },
    )


@pytest.fixture
def keyword_config() -> ComponentModel:
    """Return a keyword search configuration."""
    return ComponentModel(
        provider="autogen_ext.tools.azure.MockAzureAISearchTool",
        config={
            "name": "KeywordSearch",
            "description": "Keyword search tool",
            "endpoint": "https://test-search-service.search.windows.net",
            "index_name": "test-index",
            "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
            "query_type": "keyword",
            "search_fields": ["content", "title"],
            "select_fields": ["id", "content", "title", "source"],
        },
    )


@pytest.fixture
def vector_config() -> ComponentModel:
    """Create a test configuration for vector search."""
    return ComponentModel(
        provider="autogen_ext.tools.azure.MockAzureAISearchTool",
        config={
            "name": "VectorSearch",
            "description": "Vector search tool",
            "endpoint": "https://test-search-service.search.windows.net",
            "index_name": "test-index",
            "api_version": "2023-10-01-Preview",
            "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
            "query_type": "vector",
            "vector_fields": ["embedding"],
            "select_fields": ["id", "content", "title", "source"],
            "top": 5,
        },
    )


@pytest.fixture
def hybrid_config() -> ComponentModel:
    """Create a test configuration for hybrid search."""
    return ComponentModel(
        provider="autogen_ext.tools.azure.MockAzureAISearchTool",
        config={
            "name": "HybridSearch",
            "description": "Hybrid search tool",
            "endpoint": "https://test-search-service.search.windows.net",
            "index_name": "test-index",
            "api_version": "2023-10-01-Preview",
            "credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
            "query_type": "keyword",
            "search_fields": ["content", "title"],
            "vector_fields": ["embedding"],
            "select_fields": ["id", "content", "title", "source"],
            "top": 5,
        },
    )


@pytest.fixture
def mock_search_response() -> List[Dict[str, Any]]:
    """Create a mock search response."""
    return [
        {
            "@search.score": 0.95,
            "id": "doc1",
            "content": "This is the first document content",
            "title": "Document 1",
            "source": "test-source-1",
        },
        {
            "@search.score": 0.85,
            "id": "doc2",
            "content": "This is the second document content",
            "title": "Document 2",
            "source": "test-source-2",
        },
    ]


class AsyncIterator:
    """Async iterator for testing."""

    def __init__(self, items: List[Dict[str, Any]]) -> None:
        self.items = items.copy()

    def __aiter__(self) -> "AsyncIterator":
        return self

    async def __anext__(self) -> Dict[str, Any]:
        if not self.items:
            raise StopAsyncIteration
        return self.items.pop(0)

    async def get_count(self) -> int:
        """Return count of items."""
        return len(self.items)


@pytest.fixture
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> tuple[MagicMock, Any]:
    """Create a mock search client for testing."""
    mock_client = MagicMock()
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=None)

    search_results = AsyncIterator(mock_search_response)
    mock_client.search = MagicMock(return_value=search_results)

    patcher = patch("azure.search.documents.aio.SearchClient", return_value=mock_client)

    return mock_client, patcher


def test_validate_credentials_scenarios() -> None:
    """Test all validate_credentials scenarios to ensure full code coverage."""
    import sys

    from autogen_ext.tools.azure._config import AzureAISearchConfig

    module_path = sys.modules[AzureAISearchConfig.__module__].__file__
    if module_path is not None:
        assert "autogen-ext" in module_path

    data: Any = "not a dict"
    result: Any = AzureAISearchConfig.validate_credentials(data)  # type: ignore
    assert result == data

    data_empty: Dict[str, Any] = {}
    result_empty: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_empty)  # type: ignore
    assert isinstance(result_empty, dict)

    data_items: Dict[str, Any] = {"key1": "value1", "key2": "value2"}
    result_items: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_items)  # type: ignore
    assert result_items["key1"] == "value1"
    assert result_items["key2"] == "value2"

    data_with_api_key: Dict[str, Any] = {
        "name": "test",
        "endpoint": "https://test.search.windows.net",
        "index_name": "test-index",
        "credential": {"api_key": "test-key"},
    }
    result_with_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_api_key)  # type: ignore

    cred = result_with_api_key["credential"]  # type: ignore
    assert isinstance(cred, (AzureKeyCredential, MockAzureKeyCredential))
    assert hasattr(cred, "key")
    assert cred.key == "test-key"  # type: ignore

    credential: Any = AzureKeyCredential("test-key")
    data_with_credential: Dict[str, Any] = {
        "name": "test",
        "endpoint": "https://test.search.windows.net",
        "index_name": "test-index",
        "credential": credential,
    }
    result_with_credential: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_credential)  # type: ignore
    assert result_with_credential["credential"] is credential

    data_without_api_key: Dict[str, Any] = {
        "name": "test",
        "endpoint": "https://test.search.windows.net",
        "index_name": "test-index",
        "credential": {"username": "test-user", "password": "test-pass"},
    }
    result_without_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_without_api_key)  # type: ignore
    assert result_without_api_key["credential"] == {"username": "test-user", "password": "test-pass"}


def test_model_dump_scenarios() -> None:
    """Test all model_dump scenarios to ensure full code coverage."""
    import sys

    from autogen_ext.tools.azure._config import AzureAISearchConfig

    module_path = sys.modules[AzureAISearchConfig.__module__].__file__
    if module_path is not None:
        assert "autogen-ext" in module_path

    config = AzureAISearchConfig(
        name="test",
        endpoint="https://endpoint",
        index_name="index",
        credential=AzureKeyCredential("key"),  # type: ignore
    )
    result = config.model_dump()
    assert result["credential"] == {"type": "AzureKeyCredential"}

    if azure_sdk_available:
        from azure.core.credentials import AccessToken
        from azure.core.credentials import TokenCredential as RealTokenCredential

        class TestTokenCredential(RealTokenCredential):
            def get_token(self, *args: Any, **kwargs: Any) -> AccessToken:
                """Override of get_token method that returns proper type."""
                return AccessToken("test-token", 12345)

        config = AzureAISearchConfig(
            name="test", endpoint="https://endpoint", index_name="index", credential=TestTokenCredential()
        )
        result = config.model_dump()
        assert result["credential"] == {"type": "TokenCredential"}
    else:
        pytest.skip("Skipping TokenCredential test - Azure SDK not available")
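
The `mock_search_client` fixture pairs a `MagicMock` client with an async iterator of canned documents. The tests that consume it are in a suppressed diff, so the following is only a sketch, using the standard library alone, of how such a mock can be driven with `async with` and `async for`. `FakeResults`, `make_client`, and `collect_titles` are hypothetical names introduced for illustration.

```python
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock, MagicMock


class FakeResults:
    """Async iterator over canned documents, similar to AsyncIterator above."""

    def __init__(self, items: List[Dict[str, Any]]) -> None:
        self._items = list(items)

    def __aiter__(self) -> "FakeResults":
        return self

    async def __anext__(self) -> Dict[str, Any]:
        if not self._items:
            raise StopAsyncIteration
        return self._items.pop(0)


def make_client(docs: List[Dict[str, Any]]) -> MagicMock:
    """Build a mock client that works as an async context manager."""
    client = MagicMock()
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=None)
    # search() is a plain call that returns an async iterable of documents.
    client.search = MagicMock(return_value=FakeResults(docs))
    return client


async def collect_titles(client: Any, query: str) -> List[str]:
    """Run one search against the client and collect the document titles."""
    async with client:
        results = client.search(search_text=query)
        return [doc["title"] async for doc in results]


docs = [{"title": "Document 1"}, {"title": "Document 2"}]
client = make_client(docs)
titles = asyncio.run(collect_titles(client, "quantum computing"))
print(titles)
```

Because `search` is a `MagicMock`, a test can also assert on the arguments it received, for example with `client.search.assert_called_once_with(search_text="quantum computing")`.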

File diff suppressed because it is too large

View File

@ -596,6 +596,7 @ azure = [
    { name = "azure-ai-inference" },
    { name = "azure-core" },
    { name = "azure-identity" },
    { name = "azure-search-documents" },
]
chromadb = [
    { name = "chromadb" },
@ -732,6 +733,7 @ requires-dist = [
    { name = "azure-ai-inference", marker = "extra == 'azure'", specifier = ">=1.0.0b7" },
    { name = "azure-core", marker = "extra == 'azure'" },
    { name = "azure-identity", marker = "extra == 'azure'" },
    { name = "azure-search-documents", marker = "extra == 'azure'", specifier = ">=11.4.0" },
    { name = "chromadb", marker = "extra == 'chromadb'" },
    { name = "chromadb", marker = "extra == 'task-centric-memory'", specifier = ">=0.6.3" },
    { name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" },