mirror of https://github.com/microsoft/autogen.git
Feature: Add LlamaCppChatCompletionClient and llama-cpp (#5326)
This pull request introduces the integration of the `llama-cpp` library into the `autogen-ext` package, with changes to the project dependencies and the implementation of a new chat completion client. The most important changes are the updated project dependencies, a new module for the `LlamaCppChatCompletionClient`, and the implementation of the client itself.

### Project Dependencies

* [`python/packages/autogen-ext/pyproject.toml`](diffhunk://#diff-095119d4420ff09059557bd25681211d1772c2be0fbe0ff2d551a3726eff1b4bR34-R38): Added `llama-cpp-python` as a new dependency under the `llama-cpp` extra.

### New Module

* [`python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py`](diffhunk://#diff-42ae3ba17d51ca917634c4ea3c5969cf930297c288a783f8d9c126f2accef71dR1-R8): Introduced the `LlamaCppChatCompletionClient` class and handled import errors with a descriptive message for missing dependencies.

### Implementation of `LlamaCppChatCompletionClient`

* `python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py`:
  - Added the `LlamaCppChatCompletionClient` class with methods to initialize the client, create chat completions, detect and execute tools, and handle streaming responses.
  - Included detailed logging for debugging purposes and implemented methods to count tokens, track usage, and provide model information.

…d chat capabilities

## Why are these changes needed?

## Related issue number

## Checks

- [x] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: aribornstein <x@x.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
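For quick orientation, the basic usage pattern of the new client (taken from the docstring added later in this diff) looks like the following; the model path is a placeholder:

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    # Placeholder path; point this at a local GGUF model file.
    llama_client = LlamaCppChatCompletionClient(model_path="/path/to/your/model.gguf")
    result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
    print(result)


asyncio.run(main())
```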
This commit is contained in:
parent a1858efac9
commit 6a3acc4548
@@ -197,7 +197,7 @@ jobs:
      - name: Install Python deps
        run: |
          uv sync --locked --all-extras
          uv sync --locked --all-extras --no-extra llama-cpp
        shell: pwsh
        working-directory: ./python
@@ -54,6 +54,7 @@ python/autogen_ext.models.azure
python/autogen_ext.models.anthropic
python/autogen_ext.models.semantic_kernel
python/autogen_ext.models.ollama
python/autogen_ext.models.llama_cpp
python/autogen_ext.tools.code_execution
python/autogen_ext.tools.graphrag
python/autogen_ext.tools.http
@@ -0,0 +1,9 @@
autogen\_ext.models.llama\_cpp
==============================


.. automodule:: autogen_ext.models.llama_cpp
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
@@ -32,6 +32,11 @@ file-surfer = [
    "autogen-agentchat==0.4.8",
    "markitdown~=0.0.1",
]

llama-cpp = [
    "llama-cpp-python>=0.1.9",
]

graphrag = ["graphrag>=1.0.1"]
chromadb = ["chromadb"]
web-surfer = [
@@ -0,0 +1,10 @@
try:
    from ._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError as e:
    raise ImportError(
        "Dependencies for Llama Cpp not found. "
        "Please install llama-cpp-python: "
        "pip install autogen-ext[llama-cpp]"
    ) from e

__all__ = ["LlamaCppChatCompletionClient"]
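A minimal sketch of how this guarded import surfaces to users when the extra is missing (a hypothetical check script; the import path and error text come from the module above):

```python
# Hypothetical check that the optional llama-cpp extra is installed.
try:
    from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient  # noqa: F401

    print("llama-cpp extra is available")
except ImportError as e:
    # The package raises a descriptive ImportError that points at:
    #   pip install "autogen-ext[llama-cpp]"
    print(f"llama-cpp extra missing: {e}")
```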
@@ -0,0 +1,426 @@
import logging  # added import
import re
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast

from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from llama_cpp import (
    ChatCompletionFunctionParameters,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestFunctionMessage,
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestUserMessage,
    ChatCompletionTool,
    ChatCompletionToolFunction,
    Llama,
    llama_chat_format,
)
from typing_extensions import Unpack

logger = logging.getLogger(EVENT_LOGGER_NAME)  # initialize logger


def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
    if stop_reason is None:
        return "unknown"

    # Convert to lower case
    stop_reason = stop_reason.lower()

    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
        "stop": "stop",
        "length": "length",
        "content_filter": "content_filter",
        "function_calls": "function_calls",
        "end_turn": "stop",
        "tool_calls": "function_calls",
    }

    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


def normalize_name(name: str) -> str:
    """
    LLMs sometimes ask for functions while ignoring their own format requirements; this function should be used to replace invalid characters with "_".

    Prefer _assert_valid_name for validating user configuration or input.
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid; raises ValueError if not.

    For munging LLM responses use _normalize_name to ensure LLM-specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
    return name


def convert_tools(
    tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionTool]:
    result: List[ChatCompletionTool] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema
        else:
            assert isinstance(tool, dict)
            tool_schema = tool

        result.append(
            ChatCompletionTool(
                type="function",
                function=ChatCompletionToolFunction(
                    name=tool_schema["name"],
                    description=(tool_schema["description"] if "description" in tool_schema else ""),
                    parameters=(
                        cast(ChatCompletionFunctionParameters, tool_schema["parameters"])
                        if "parameters" in tool_schema
                        else {}
                    ),
                ),
            )
        )
    # Check if all tools have valid names.
    for tool_param in result:
        assert_valid_name(tool_param["function"]["name"])
    return result


class LlamaCppParams(TypedDict, total=False):
    # from_pretrained parameters:
    repo_id: Optional[str]
    filename: Optional[str]
    additional_files: Optional[List[Any]]
    local_dir: Optional[str]
    local_dir_use_symlinks: Union[bool, Literal["auto"]]
    cache_dir: Optional[str]
    # __init__ parameters:
    model_path: str
    n_gpu_layers: int
    split_mode: int
    main_gpu: int
    tensor_split: Optional[List[float]]
    rpc_servers: Optional[str]
    vocab_only: bool
    use_mmap: bool
    use_mlock: bool
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]]
    seed: int
    n_ctx: int
    n_batch: int
    n_ubatch: int
    n_threads: Optional[int]
    n_threads_batch: Optional[int]
    rope_scaling_type: Optional[int]
    pooling_type: int
    rope_freq_base: float
    rope_freq_scale: float
    yarn_ext_factor: float
    yarn_attn_factor: float
    yarn_beta_fast: float
    yarn_beta_slow: float
    yarn_orig_ctx: int
    logits_all: bool
    embedding: bool
    offload_kqv: bool
    flash_attn: bool
    no_perf: bool
    last_n_tokens_size: int
    lora_base: Optional[str]
    lora_scale: float
    lora_path: Optional[str]
    numa: Union[bool, int]
    chat_format: Optional[str]
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler]
    draft_model: Optional[Any]  # LlamaDraftModel not exposed by llama_cpp
    tokenizer: Optional[Any]  # BaseLlamaTokenizer not exposed by llama_cpp
    type_k: Optional[int]
    type_v: Optional[int]
    spm_infill: bool
    verbose: bool


class LlamaCppChatCompletionClient(ChatCompletionClient):
    """Chat completion client for LlamaCpp models.
    To use this client, you must install the `llama-cpp` extra:

    .. code-block:: bash

        pip install "autogen-ext[llama-cpp]"

    This client allows you to interact with LlamaCpp models, either by specifying a local model path or by downloading a model from Hugging Face Hub.

    Args:
        model_path (optional, str): The path to the LlamaCpp model file. Required if repo_id and filename are not provided.
        repo_id (optional, str): The Hugging Face Hub repository ID. Required if model_path is not provided.
        filename (optional, str): The filename of the model within the Hugging Face Hub repository. Required if model_path is not provided.
        n_gpu_layers (optional, int): The number of layers to put on the GPU.
        n_ctx (optional, int): The context size.
        n_batch (optional, int): The batch size.
        verbose (optional, bool): Whether to print verbose output.
        model_info (optional, ModelInfo): The capabilities of the model. Defaults to a ModelInfo instance with function_calling set to True.
        **kwargs: Additional parameters to pass to the Llama class.

    Examples:

        The following code snippet shows how to use the client with a local model file:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(model_path="/path/to/your/model.gguf")
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())

        The following code snippet shows how to use the client with a model from Hugging Face Hub:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(
                    repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
                )
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())
    """

    def __init__(
        self,
        model_info: Optional[ModelInfo] = None,
        **kwargs: Unpack[LlamaCppParams],
    ) -> None:
        """
        Initialize the LlamaCpp client.
        """

        if model_info:
            validate_model_info(model_info)

        if "repo_id" in kwargs and "filename" in kwargs and kwargs["repo_id"] and kwargs["filename"]:
            repo_id: str = cast(str, kwargs.pop("repo_id"))
            filename: str = cast(str, kwargs.pop("filename"))
            pretrained = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)  # type: ignore
            assert isinstance(pretrained, Llama)
            self.llm = pretrained

        elif "model_path" in kwargs:
            self.llm = Llama(**kwargs)  # pyright: ignore[reportUnknownMemberType]
        else:
            raise ValueError("Please provide model_path for a local model, or provide repo_id and filename to download a model from Hugging Face Hub.")
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
        converted_messages: list[
            ChatCompletionRequestSystemMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestAssistantMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestToolMessage
            | ChatCompletionRequestFunctionMessage
        ] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "assistant", "content": msg.content})
            elif (
                isinstance(msg, SystemMessage) or isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage)
            ) and isinstance(msg.content, list):
                raise ValueError("Multi-part messages such as those containing images are currently not supported.")
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        if self.model_info["function_calling"]:
            response = self.llm.create_chat_completion(
                messages=converted_messages, tools=convert_tools(tools), stream=False
            )
        else:
            response = self.llm.create_chat_completion(messages=converted_messages, stream=False)

        if not isinstance(response, dict):
            raise ValueError("Unexpected response type from LlamaCpp model.")

        self._total_usage["prompt_tokens"] += response["usage"]["prompt_tokens"]
        self._total_usage["completion_tokens"] += response["usage"]["completion_tokens"]

        # Parse the response
        response_tool_calls: ChatCompletionTool | None = None
        response_text: str | None = None
        if "choices" in response and len(response["choices"]) > 0:
            if "message" in response["choices"][0]:
                response_text = response["choices"][0]["message"]["content"]
            if "tool_calls" in response["choices"][0]:
                response_tool_calls = response["choices"][0]["tool_calls"]  # type: ignore

        content: List[FunctionCall] | str = ""
        thought: str | None = None
        if response_tool_calls:
            content = []
            for tool_call in response_tool_calls:
                if not isinstance(tool_call, dict):
                    raise ValueError("Unexpected tool call type from LlamaCpp model.")
                content.append(
                    FunctionCall(
                        id=tool_call["id"],
                        arguments=tool_call["function"]["arguments"],
                        name=normalize_name(tool_call["function"]["name"]),
                    )
                )
            if response_text and len(response_text) > 0:
                thought = response_text
        else:
            if response_text:
                content = response_text

        # Detect tool usage in the response
        if not response_tool_calls and not response_text:
            logger.debug("DEBUG: No response text found. Returning empty response.")
            return CreateResult(
                content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False
            )

        # Create a CreateResult object
        if "finish_reason" in response["choices"][0]:
            finish_reason = response["choices"][0]["finish_reason"]
        else:
            finish_reason = "unknown"
        if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
            finish_reason = "unknown"
        create_result = CreateResult(
            content=content,
            thought=thought,
            usage=cast(RequestUsage, response["usage"]),
            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
            cached=False,
        )

        # If we are running in the context of a handler we can get the agent_id
        try:
            agent_id = MessageHandlerContext.agent_id()
        except RuntimeError:
            agent_id = None

        logger.info(
            LLMCallEvent(
                messages=cast(List[Dict[str, Any]], converted_messages),
                response=create_result.model_dump(),
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["completion_tokens"],
                agent_id=agent_id,
            )
        )
        return create_result

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
        yield ""

    # Implement abstract methods
    def actual_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )

    @property
    def capabilities(self) -> ModelInfo:
        return self.model_info

    def count_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        total = 0
        for msg in messages:
            # Use the Llama model's tokenizer to encode the content
            tokens = self.llm.tokenize(str(msg.content).encode("utf-8"))
            total += len(tokens)
        return total

    @property
    def model_info(self) -> ModelInfo:
        return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)

    def remaining_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        used_tokens = self.count_tokens(messages)
        return max(self.llm.n_ctx() - used_tokens, 0)

    def total_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )

    async def close(self) -> None:
        """
        Close the LlamaCpp client.
        """
        self.llm.close()
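For illustration, here is a hedged sketch of passing a tool through `create()` as implemented above. The `add` tool and model path are hypothetical, and whether the model actually emits a tool call depends on the underlying GGUF model:

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


def add(num1: int, num2: int) -> int:
    """Add two numbers together."""
    return num1 + num2


async def main() -> None:
    # Placeholder path; point this at a local GGUF model file.
    client = LlamaCppChatCompletionClient(model_path="/path/to/your/model.gguf")
    add_tool = FunctionTool(add, description="Add two numbers together.")
    result = await client.create(
        [UserMessage(content="What is 3 + 4?", source="user")],
        tools=[add_tool],
    )
    # result.content is either plain text or a list of FunctionCall objects
    # when the model decides to call a tool (see convert_tools/create above).
    print(result.content)
    await client.close()


asyncio.run(main())
```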
@@ -3,8 +3,8 @@
import asyncio
import os
import shutil
import platform
import shutil
import sys
import tempfile
import venv
@@ -18,7 +18,6 @@ from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor


HAS_POWERSHELL: bool = platform.system() == "Windows" and (
    shutil.which("powershell") is not None or shutil.which("pwsh") is not None
)
@@ -0,0 +1,205 @@
import contextlib
import sys
from typing import TYPE_CHECKING, Any, ContextManager, Generator, List, Sequence, Union

import pytest
import torch

# from autogen_agentchat.agents import AssistantAgent
# from autogen_agentchat.messages import TextMessage
# from autogen_core import CancellationToken
from autogen_core.models import RequestUsage, SystemMessage, UserMessage

# from autogen_core.tools import FunctionTool
try:
    from llama_cpp import ChatCompletionMessageToolCalls

    if TYPE_CHECKING:
        from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError:
    # If llama_cpp is not installed, we can't run the tests.
    pytest.skip("Skipping LlamaCppChatCompletionClient tests: llama-cpp-python not installed", allow_module_level=True)


# Fake Llama class to simulate responses
class FakeLlama:
    def __init__(
        self,
        model_path: str,
        **_: Any,
    ) -> None:
        self.model_path = model_path
        self.n_ctx = lambda: 1024

    # Added tokenize method for testing purposes.
    def tokenize(self, b: bytes) -> list[int]:
        return list(b)

    def create_chat_completion(
        self, messages: Any, tools: List[ChatCompletionMessageToolCalls] | None, stream: bool = False
    ) -> dict[str, Any]:
        # Return fake non-streaming response.

        return {
            "usage": {"prompt_tokens": 1, "completion_tokens": 2},
            "choices": [{"message": {"content": "Fake response"}}],
        }

    def __call__(self, prompt: str, stream: bool = True) -> Generator[dict[str, Any], None, None]:
        # Yield fake streaming tokens.
        yield {"choices": [{"text": "Hello "}]}
        yield {"choices": [{"text": "World"}]}


@pytest.fixture
@contextlib.contextmanager
def get_completion_client(
    monkeypatch: pytest.MonkeyPatch,
) -> "Generator[type[LlamaCppChatCompletionClient], None, None]":
    with monkeypatch.context() as m:
        m.setattr("llama_cpp.Llama", FakeLlama)
        from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

        yield LlamaCppChatCompletionClient
    sys.modules.pop("autogen_ext.models.llama_cpp._llama_cpp_completion_client", None)
    sys.modules.pop("llama_cpp", None)


@pytest.mark.asyncio
async def test_llama_cpp_create(get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]") -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        messages: Sequence[Union[SystemMessage, UserMessage]] = [
            SystemMessage(content="Test system"),
            UserMessage(content="Test user", source="user"),
        ]
        result = await client.create(messages=messages)
        assert result.content == "Fake response"
        usage: RequestUsage = result.usage
        assert usage.prompt_tokens == 1
        assert usage.completion_tokens == 2
        assert result.finish_reason in ("stop", "unknown")


# Commented out due to raising NotImplementedError; will leave in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_create_stream(
#     get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
# ) -> None:
#     with get_completion_client as Client:
#         client = Client(filename="dummy")
#         messages: Sequence[Union[SystemMessage, UserMessage]] = [
#             SystemMessage(content="Test system"),
#             UserMessage(content="Test user", source="user"),
#         ]
#         collected = ""
#         async for token in client.create_stream(messages=messages):
#             collected += token
#         assert collected == "Hello World"


@pytest.mark.asyncio
async def test_create_invalid_message(
    get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
) -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        # Pass an unsupported message type (integer) to trigger ValueError.
        with pytest.raises(ValueError, match="Unsupported message type"):
            await client.create(messages=[123])  # type: ignore


@pytest.mark.asyncio
async def test_count_and_remaining_tokens(
    get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]", monkeypatch: pytest.MonkeyPatch
) -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        msg = SystemMessage(content="Test")
        # count_tokens should count the bytes
        token_count = client.count_tokens([msg])
        # Since "Test" encoded is 4 bytes, expect 4 tokens.
        assert token_count >= 4
        remaining = client.remaining_tokens([msg])
        # remaining should be (1024 - token_count); ensure non-negative.
        assert remaining == max(1024 - token_count, 0)


@pytest.mark.asyncio
async def test_llama_cpp_integration_non_streaming() -> None:
    if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
        pytest.skip("Skipping LlamaCpp integration tests: GPU not available")

    from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

    client = LlamaCppChatCompletionClient(
        repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
    )
    messages: Sequence[Union[SystemMessage, UserMessage]] = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Hello, how are you?", source="user"),
    ]
    result = await client.create(messages=messages)
    assert isinstance(result.content, str) and len(result.content.strip()) > 0


# Commented out due to raising NotImplementedError; will leave in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_integration_streaming() -> None:
#     if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
#         pytest.skip("Skipping LlamaCpp integration tests: GPU not available")

#     from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
#     client = LlamaCppChatCompletionClient(
#         repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
#     )
#     messages: Sequence[Union[SystemMessage, UserMessage]] = [
#         SystemMessage(content="You are a helpful assistant."),
#         UserMessage(content="Please stream your response.", source="user"),
#     ]
#     collected = ""
#     async for token in client.create_stream(messages=messages):
#         collected += token
#     assert isinstance(collected, str) and len(collected.strip()) > 0

# Commented out tool use as this functionality is not yet implemented for Phi-4.
# Define tools (functions) for the AssistantAgent
# def add(num1: int, num2: int) -> int:
#     """Add two numbers together"""
#     return num1 + num2


# @pytest.mark.asyncio
# async def test_llama_cpp_integration_tool_use() -> None:
#     if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
#         pytest.skip("Skipping LlamaCpp integration tests: GPU not available")

#     from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

#     model_client = LlamaCppChatCompletionClient(
#         repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
#     )

#     # Initialize the AssistantAgent
#     assistant = AssistantAgent(
#         name="assistant",
#         system_message=("You can add two numbers together using the `add` function. "),
#         model_client=model_client,
#         tools=[
#             FunctionTool(
#                 add,
#                 description="Add two numbers together. The first argument is num1 and second is num2. The return value is num1 + num2",
#             )
#         ],
#         reflect_on_tool_use=True,  # Reflect on tool results
#     )

#     # Test the tool
#     response = await assistant.on_messages(
#         [
#             TextMessage(content="add 3 and 4", source="user"),
#         ],
#         CancellationToken(),
#     )

#     assert "7" in response.chat_message.content
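The integration tests above gate on GPU availability; a small standalone sketch of the same check (assuming `torch` is installed) looks like this:

```python
import torch

# Mirrors the gate used by the integration tests above: MPS on macOS or CUDA elsewhere.
gpu_available = (
    hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
) or torch.cuda.is_available()
print("GPU available:", gpu_available)
```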
@@ -628,6 +628,9 @@ jupyter-executor = [
langchain = [
    { name = "langchain-core" },
]
llama-cpp = [
    { name = "llama-cpp-python" },
]
magentic-one = [
    { name = "autogen-agentchat" },
    { name = "markitdown" },
@@ -735,6 +738,7 @@ requires-dist = [
    { name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" },
    { name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" },
    { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
    { name = "llama-cpp-python", marker = "extra == 'llama-cpp'", specifier = ">=0.1.9" },
    { name = "markitdown", marker = "extra == 'file-surfer'", specifier = "~=0.0.1" },
    { name = "markitdown", marker = "extra == 'magentic-one'", specifier = "~=0.0.1" },
    { name = "markitdown", marker = "extra == 'web-surfer'", specifier = "~=0.0.1" },
@@ -3501,6 +3505,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/43/b1/9355547c3b9043ba2821e7797f322c753dfa4d2a3da7bb05690fce689eaa/llama_cloud-0.1.11-py3-none-any.whl", hash = "sha256:b703765d03783a5a0fc57a52adc9892f8b91b0c19bbecb85a54ad4e813342951", size = 250609 },
]

[[package]]
name = "llama-cpp-python"
version = "0.3.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "diskcache" },
    { name = "jinja2" },
    { name = "numpy" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a6/38/7a47b1fb1d83eaddd86ca8ddaf20f141cbc019faf7b425283d8e5ef710e5/llama_cpp_python-0.3.7.tar.gz", hash = "sha256:0566a0dcc0f38005c4093309a87f67c2452449522e3e17e15cd735a62957894c", size = 66715891 }

[[package]]
name = "llama-index"
version = "0.12.14"