mirror of https://github.com/microsoft/autogen.git
Add Azure AI Search tool implementation (#5844)
# Azure AI Search Tool Implementation This PR adds a new tool for Azure AI Search integration to autogen-ext, enabling agents to search and retrieve information from Azure AI Search indexes. ## Why Are These Changes Needed? AutoGen currently lacks native integration with Azure AI Search, which is a powerful enterprise search service that supports semantic, vector, and hybrid search capabilities. This integration enables agents to: 1. Retrieve relevant information from large document collections 2. Perform semantic search with AI-powered ranking 3. Execute vector similarity search using embeddings 4. Combine text and vector approaches for optimal results This tool complements existing retrieval capabilities and provides a seamless way to integrate with Azure's search infrastructure. ## Features - **Multiple Search Types**: Support for text, semantic, vector, and hybrid search - **Flexible Configuration**: Customizable search parameters and fields - **Robust Error Handling**: User-friendly error messages with actionable guidance - **Performance Optimizations**: Configurable caching and retry mechanisms - **Vector Search Support**: Built-in embedding generation with extensibility ## Usage Example ```python from autogen_ext.tools.azure import AzureAISearchTool from azure.core.credentials import AzureKeyCredential from autogen import AssistantAgent, UserProxyAgent # Create the search tool search_tool = AzureAISearchTool.load_component({ "provider": "autogen_ext.tools.azure.AzureAISearchTool", "config": { "name": "DocumentSearch", "description": "Search for information in the knowledge base", "endpoint": "https://your-service.search.windows.net", "index_name": "your-index", "credential": {"api_key": "your-api-key"}, "query_type": "semantic", "semantic_config_name": "default" } }) # Create an agent with the search tool assistant = AssistantAgent( "assistant", llm_config={"tools": [search_tool]} ) # Create a user proxy agent user_proxy = UserProxyAgent( "user_proxy", human_input_mode="TERMINATE", max_consecutive_auto_reply=10, code_execution_config={"work_dir": "coding"} ) # Start the conversation user_proxy.initiate_chat( assistant, message="What information do we have about quantum computing in our knowledge base?" ) ``` ## Testing - Added unit tests for all search types (text, semantic, vector, hybrid) - Added tests for error handling and cancellation - All tests pass locally ## Documentation - Added comprehensive docstrings with examples - Included warnings about placeholder embedding implementation - Added links to Azure AI Search documentation ## Related issue number Closes #5419 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
d7f2b56846
commit
0d9b574d09
|
@ -55,6 +55,7 @@ python/autogen_ext.models.anthropic
|
|||
python/autogen_ext.models.semantic_kernel
|
||||
python/autogen_ext.models.ollama
|
||||
python/autogen_ext.models.llama_cpp
|
||||
python/autogen_ext.tools.azure
|
||||
python/autogen_ext.tools.code_execution
|
||||
python/autogen_ext.tools.graphrag
|
||||
python/autogen_ext.tools.http
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
autogen\_ext.tools.azure
|
||||
========================
|
||||
|
||||
.. automodule:: autogen_ext.tools.azure
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
|
@ -25,6 +25,7 @@ azure = [
|
|||
"azure-ai-inference>=1.0.0b7",
|
||||
"azure-core",
|
||||
"azure-identity",
|
||||
"azure-search-documents>=11.4.0",
|
||||
]
|
||||
docker = ["docker~=7.0", "asyncio_atexit>=1.0.1"]
|
||||
ollama = ["ollama>=0.4.7", "tiktoken>=0.8.0"]
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
from ._ai_search import (
|
||||
AzureAISearchTool,
|
||||
BaseAzureAISearchTool,
|
||||
SearchQuery,
|
||||
SearchResult,
|
||||
SearchResults,
|
||||
)
|
||||
from ._config import AzureAISearchConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAISearchTool",
|
||||
"BaseAzureAISearchTool",
|
||||
"SearchQuery",
|
||||
"SearchResult",
|
||||
"SearchResults",
|
||||
"AzureAISearchConfig",
|
||||
]
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,177 @@
|
|||
"""Configuration for Azure AI Search tool.
|
||||
|
||||
This module provides configuration classes for the Azure AI Search tool, including
|
||||
settings for authentication, search behavior, retry policies, and caching.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential, TokenCredential
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
# Add explicit ignore for the specific model validator error
|
||||
# pyright: reportArgumentType=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
|
||||
T = TypeVar("T", bound="AzureAISearchConfig")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AzureAISearchConfig(BaseModel):
|
||||
"""Configuration for Azure AI Search tool.
|
||||
|
||||
This class defines the configuration parameters for :class:`AzureAISearchTool`.
|
||||
It provides options for customizing search behavior including query types,
|
||||
field selection, authentication, retry policies, and caching strategies.
|
||||
|
||||
.. note::
|
||||
|
||||
This class requires the :code:`azure` extra for the :code:`autogen-ext` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "autogen-ext[azure]"
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from autogen_ext.tools.azure import AzureAISearchConfig
|
||||
|
||||
config = AzureAISearchConfig(
|
||||
name="doc_search",
|
||||
endpoint="https://my-search.search.windows.net",
|
||||
index_name="my-index",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="vector",
|
||||
vector_fields=["embedding"],
|
||||
)
|
||||
|
||||
For more details, see:
|
||||
* `Azure AI Search Overview <https://learn.microsoft.com/azure/search/search-what-is-azure-search>`_
|
||||
* `Vector Search <https://learn.microsoft.com/azure/search/vector-search-overview>`_
|
||||
|
||||
Args:
|
||||
name (str): Name for the tool instance, used to identify it in the agent's toolkit.
|
||||
description (Optional[str]): Human-readable description of what this tool does and how to use it.
|
||||
endpoint (str): The full URL of your Azure AI Search service, in the format
|
||||
'https://<service-name>.search.windows.net'.
|
||||
index_name (str): Name of the target search index in your Azure AI Search service.
|
||||
The index must be pre-created and properly configured.
|
||||
api_version (str): Azure AI Search REST API version to use. Defaults to '2023-11-01'.
|
||||
Only change if you need specific features from a different API version.
|
||||
credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential:
|
||||
- AzureKeyCredential: For API key authentication (admin/query key)
|
||||
- TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential)
|
||||
query_type (Literal["keyword", "fulltext", "vector", "hybrid"]): The search query mode to use:
|
||||
- 'keyword': Basic keyword search (default)
|
||||
- 'full': Full Lucene query syntax
|
||||
- 'vector': Vector similarity search
|
||||
- 'hybrid': Hybrid search combining multiple techniques
|
||||
search_fields (Optional[List[str]]): List of index fields to search within. If not specified,
|
||||
searches all searchable fields. Example: ['title', 'content'].
|
||||
select_fields (Optional[List[str]]): Fields to return in search results. If not specified,
|
||||
returns all fields. Use to optimize response size.
|
||||
vector_fields (Optional[List[str]]): Vector field names for vector search. Must be configured
|
||||
in your search index as vector fields. Required for vector search.
|
||||
top (Optional[int]): Maximum number of documents to return in search results.
|
||||
Helps control response size and processing time.
|
||||
retry_enabled (bool): Whether to enable retry policy for transient errors. Defaults to True.
|
||||
retry_max_attempts (Optional[int]): Maximum number of retry attempts for failed requests. Defaults to 3.
|
||||
retry_mode (Literal["fixed", "exponential"]): Retry backoff strategy: fixed or exponential. Defaults to "exponential".
|
||||
enable_caching (bool): Whether to enable client-side caching of search results. Defaults to False.
|
||||
cache_ttl_seconds (int): Time-to-live for cached search results in seconds. Defaults to 300 (5 minutes).
|
||||
filter (Optional[str]): OData filter expression to refine search results.
|
||||
"""
|
||||
|
||||
name: str = Field(description="The name of the tool")
|
||||
description: Optional[str] = Field(default=None, description="A description of the tool")
|
||||
endpoint: str = Field(description="The endpoint URL for your Azure AI Search service")
|
||||
index_name: str = Field(description="The name of the search index to query")
|
||||
api_version: str = Field(default="2023-11-01", description="API version to use")
|
||||
credential: Union[AzureKeyCredential, TokenCredential] = Field(
|
||||
description="The credential to use for authentication"
|
||||
)
|
||||
query_type: Literal["keyword", "fulltext", "vector", "hybrid"] = Field(
|
||||
default="keyword", description="Type of query to perform"
|
||||
)
|
||||
search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in")
|
||||
select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results")
|
||||
vector_fields: Optional[List[str]] = Field(
|
||||
default=None, description="Optional list of vector fields for vector search"
|
||||
)
|
||||
top: Optional[int] = Field(default=None, description="Optional number of results to return")
|
||||
filter: Optional[str] = Field(default=None, description="Optional OData filter expression to refine search results")
|
||||
|
||||
retry_enabled: bool = Field(default=True, description="Whether to enable retry policy for transient errors")
|
||||
retry_max_attempts: Optional[int] = Field(
|
||||
default=3, description="Maximum number of retry attempts for failed requests"
|
||||
)
|
||||
retry_mode: Literal["fixed", "exponential"] = Field(
|
||||
default="exponential",
|
||||
description="Retry backoff strategy: fixed or exponential",
|
||||
)
|
||||
|
||||
enable_caching: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable client-side caching of search results",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=300, # 5 minutes
|
||||
description="Time-to-live for cached search results in seconds",
|
||||
)
|
||||
|
||||
embedding_provider: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of embedding provider to use (e.g., 'azure_openai', 'openai')",
|
||||
)
|
||||
embedding_model: Optional[str] = Field(default=None, description="Model name to use for generating embeddings")
|
||||
embedding_dimension: Optional[int] = Field(
|
||||
default=None, description="Dimension of embedding vectors produced by the model"
|
||||
)
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@classmethod
|
||||
@model_validator(mode="before")
|
||||
def validate_credentials(cls: Type[T], data: Any) -> Any:
|
||||
"""Validate and convert credential data."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
result = {}
|
||||
|
||||
for key, value in data.items():
|
||||
result[str(key)] = value
|
||||
|
||||
if "credential" in result:
|
||||
credential = result["credential"]
|
||||
|
||||
if isinstance(credential, dict) and "api_key" in credential:
|
||||
api_key = str(credential["api_key"])
|
||||
result["credential"] = AzureKeyCredential(api_key)
|
||||
|
||||
return result
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Custom model_dump to handle credentials."""
|
||||
result: Dict[str, Any] = super().model_dump(**kwargs)
|
||||
|
||||
if isinstance(self.credential, AzureKeyCredential):
|
||||
result["credential"] = {"type": "AzureKeyCredential"}
|
||||
elif isinstance(self.credential, TokenCredential):
|
||||
result["credential"] = {"type": "TokenCredential"}
|
||||
|
||||
return result
|
|
@ -0,0 +1,303 @@
|
|||
"""Test fixtures for Azure AI Search tool tests."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Generator, List, Protocol, Type, TypeVar, Union
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from autogen_core import ComponentModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AccessTokenProtocol(Protocol):
|
||||
"""Protocol matching Azure AccessToken."""
|
||||
|
||||
token: str
|
||||
expires_on: int
|
||||
|
||||
|
||||
class MockAccessToken:
|
||||
"""Mock implementation of AccessToken."""
|
||||
|
||||
def __init__(self, token: str, expires_on: int) -> None:
|
||||
self.token = token
|
||||
self.expires_on = expires_on
|
||||
|
||||
|
||||
class MockAzureKeyCredential:
|
||||
"""Mock implementation of AzureKeyCredential."""
|
||||
|
||||
def __init__(self, key: str) -> None:
|
||||
self.key = key
|
||||
|
||||
|
||||
class MockTokenCredential:
|
||||
"""Mock implementation of TokenCredential for testing."""
|
||||
|
||||
def get_token(
|
||||
self,
|
||||
*scopes: str,
|
||||
claims: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
enable_cae: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> AccessTokenProtocol:
|
||||
"""Mock get_token method that implements TokenCredential protocol."""
|
||||
return MockAccessToken("mock-token", 12345)
|
||||
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
||||
|
||||
_access_token_type: Type[AccessToken] = AccessToken
|
||||
azure_sdk_available = True
|
||||
except ImportError:
|
||||
AzureKeyCredential = MockAzureKeyCredential # type: ignore
|
||||
TokenCredential = MockTokenCredential # type: ignore
|
||||
_access_token_type = MockAccessToken # type: ignore
|
||||
azure_sdk_available = False
|
||||
|
||||
CredentialType = Union[AzureKeyCredential, TokenCredential, MockAzureKeyCredential, MockTokenCredential, Any]
|
||||
|
||||
needs_azure_sdk = pytest.mark.skipif(not azure_sdk_available, reason="Azure SDK not available")
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Type google.*uses PyType_Spec with a metaclass that has custom tp_new",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vectorized_query() -> Generator[MagicMock, None, None]:
|
||||
"""Create a mock VectorizedQuery for testing."""
|
||||
with patch("azure.search.documents.models.VectorizedQuery") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config() -> ComponentModel:
|
||||
"""Return a test configuration for the Azure AI Search tool."""
|
||||
return ComponentModel(
|
||||
provider="autogen_ext.tools.azure.MockAzureAISearchTool",
|
||||
config={
|
||||
"name": "TestAzureSearch",
|
||||
"description": "Test Azure AI Search Tool",
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
"top": 5,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def keyword_config() -> ComponentModel:
|
||||
"""Return a keyword search configuration."""
|
||||
return ComponentModel(
|
||||
provider="autogen_ext.tools.azure.MockAzureAISearchTool",
|
||||
config={
|
||||
"name": "KeywordSearch",
|
||||
"description": "Keyword search tool",
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_config() -> ComponentModel:
|
||||
"""Create a test configuration for vector search."""
|
||||
return ComponentModel(
|
||||
provider="autogen_ext.tools.azure.MockAzureAISearchTool",
|
||||
config={
|
||||
"name": "VectorSearch",
|
||||
"description": "Vector search tool",
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"query_type": "vector",
|
||||
"vector_fields": ["embedding"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
"top": 5,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hybrid_config() -> ComponentModel:
|
||||
"""Create a test configuration for hybrid search."""
|
||||
return ComponentModel(
|
||||
provider="autogen_ext.tools.azure.MockAzureAISearchTool",
|
||||
config={
|
||||
"name": "HybridSearch",
|
||||
"description": "Hybrid search tool",
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"vector_fields": ["embedding"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
"top": 5,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_search_response() -> List[Dict[str, Any]]:
|
||||
"""Create a mock search response."""
|
||||
return [
|
||||
{
|
||||
"@search.score": 0.95,
|
||||
"id": "doc1",
|
||||
"content": "This is the first document content",
|
||||
"title": "Document 1",
|
||||
"source": "test-source-1",
|
||||
},
|
||||
{
|
||||
"@search.score": 0.85,
|
||||
"id": "doc2",
|
||||
"content": "This is the second document content",
|
||||
"title": "Document 2",
|
||||
"source": "test-source-2",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class AsyncIterator:
|
||||
"""Async iterator for testing."""
|
||||
|
||||
def __init__(self, items: List[Dict[str, Any]]) -> None:
|
||||
self.items = items.copy()
|
||||
|
||||
def __aiter__(self) -> "AsyncIterator":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Dict[str, Any]:
|
||||
if not self.items:
|
||||
raise StopAsyncIteration
|
||||
return self.items.pop(0)
|
||||
|
||||
async def get_count(self) -> int:
|
||||
"""Return count of items."""
|
||||
return len(self.items)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> tuple[MagicMock, Any]:
|
||||
"""Create a mock search client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
search_results = AsyncIterator(mock_search_response)
|
||||
mock_client.search = MagicMock(return_value=search_results)
|
||||
|
||||
patcher = patch("azure.search.documents.aio.SearchClient", return_value=mock_client)
|
||||
|
||||
return mock_client, patcher
|
||||
|
||||
|
||||
def test_validate_credentials_scenarios() -> None:
|
||||
"""Test all validate_credentials scenarios to ensure full code coverage."""
|
||||
import sys
|
||||
|
||||
from autogen_ext.tools.azure._config import AzureAISearchConfig
|
||||
|
||||
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
|
||||
if module_path is not None:
|
||||
assert "autogen-ext" in module_path
|
||||
|
||||
data: Any = "not a dict"
|
||||
result: Any = AzureAISearchConfig.validate_credentials(data) # type: ignore
|
||||
assert result == data
|
||||
|
||||
data_empty: Dict[str, Any] = {}
|
||||
result_empty: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_empty) # type: ignore
|
||||
assert isinstance(result_empty, dict)
|
||||
|
||||
data_items: Dict[str, Any] = {"key1": "value1", "key2": "value2"}
|
||||
result_items: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_items) # type: ignore
|
||||
assert result_items["key1"] == "value1"
|
||||
assert result_items["key2"] == "value2"
|
||||
|
||||
data_with_api_key: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": {"api_key": "test-key"},
|
||||
}
|
||||
result_with_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_api_key) # type: ignore
|
||||
|
||||
cred = result_with_api_key["credential"] # type: ignore
|
||||
assert isinstance(cred, (AzureKeyCredential, MockAzureKeyCredential))
|
||||
assert hasattr(cred, "key")
|
||||
assert cred.key == "test-key" # type: ignore
|
||||
|
||||
credential: Any = AzureKeyCredential("test-key")
|
||||
data_with_credential: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": credential,
|
||||
}
|
||||
result_with_credential: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_credential) # type: ignore
|
||||
assert result_with_credential["credential"] is credential
|
||||
|
||||
data_without_api_key: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": {"username": "test-user", "password": "test-pass"},
|
||||
}
|
||||
result_without_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_without_api_key) # type: ignore
|
||||
assert result_without_api_key["credential"] == {"username": "test-user", "password": "test-pass"}
|
||||
|
||||
|
||||
def test_model_dump_scenarios() -> None:
|
||||
"""Test all model_dump scenarios to ensure full code coverage."""
|
||||
import sys
|
||||
|
||||
from autogen_ext.tools.azure._config import AzureAISearchConfig
|
||||
|
||||
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
|
||||
if module_path is not None:
|
||||
assert "autogen-ext" in module_path
|
||||
|
||||
config = AzureAISearchConfig(
|
||||
name="test",
|
||||
endpoint="https://endpoint",
|
||||
index_name="index",
|
||||
credential=AzureKeyCredential("key"), # type: ignore
|
||||
)
|
||||
result = config.model_dump()
|
||||
assert result["credential"] == {"type": "AzureKeyCredential"}
|
||||
|
||||
if azure_sdk_available:
|
||||
from azure.core.credentials import AccessToken
|
||||
from azure.core.credentials import TokenCredential as RealTokenCredential
|
||||
|
||||
class TestTokenCredential(RealTokenCredential):
|
||||
def get_token(self, *args: Any, **kwargs: Any) -> AccessToken:
|
||||
"""Override of get_token method that returns proper type."""
|
||||
return AccessToken("test-token", 12345)
|
||||
|
||||
config = AzureAISearchConfig(
|
||||
name="test", endpoint="https://endpoint", index_name="index", credential=TestTokenCredential()
|
||||
)
|
||||
result = config.model_dump()
|
||||
assert result["credential"] == {"type": "TokenCredential"}
|
||||
else:
|
||||
pytest.skip("Skipping TokenCredential test - Azure SDK not available")
|
File diff suppressed because it is too large
Load Diff
|
@ -596,6 +596,7 @@ azure = [
|
|||
{ name = "azure-ai-inference" },
|
||||
{ name = "azure-core" },
|
||||
{ name = "azure-identity" },
|
||||
{ name = "azure-search-documents" },
|
||||
]
|
||||
chromadb = [
|
||||
{ name = "chromadb" },
|
||||
|
@ -732,6 +733,7 @@ requires-dist = [
|
|||
{ name = "azure-ai-inference", marker = "extra == 'azure'", specifier = ">=1.0.0b7" },
|
||||
{ name = "azure-core", marker = "extra == 'azure'" },
|
||||
{ name = "azure-identity", marker = "extra == 'azure'" },
|
||||
{ name = "azure-search-documents", marker = "extra == 'azure'", specifier = ">=11.4.0" },
|
||||
{ name = "chromadb", marker = "extra == 'chromadb'" },
|
||||
{ name = "chromadb", marker = "extra == 'task-centric-memory'", specifier = ">=0.6.3" },
|
||||
{ name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" },
|
||||
|
|
Loading…
Reference in New Issue