mirror of https://github.com/microsoft/autogen.git
Feature: Add LlamaCppChatCompletionClient and llama-cpp extra (#5326)
This pull request introduces the integration of the `llama-cpp` library into the `autogen-ext` package, with significant changes to the project dependencies and the implementation of a new chat completion client. The most important changes include updating the project dependencies, adding a new module for the `LlamaCppChatCompletionClient`, and implementing the client with various functionalities.

### Project Dependencies:

* `python/packages/autogen-ext/pyproject.toml`: Added `llama-cpp-python` as a new dependency under the `llama-cpp` extra.

### New Module:

* `python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py`: Introduced the `LlamaCppChatCompletionClient` class and handled import errors with a descriptive message for missing dependencies.

### Implementation of `LlamaCppChatCompletionClient`:

* `python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py`:
  - Added the `LlamaCppChatCompletionClient` class with methods to initialize the client, create chat completions, detect and execute tools, and handle streaming responses.
  - Included detailed logging for debugging purposes and implemented methods to count tokens, track usage, and provide model information. …d chat capabilities

## Why are these changes needed?

They add a chat completion client for locally hosted GGUF models served through `llama-cpp-python`, so `autogen-ext` agents can run against llama.cpp models loaded either from a local file or from the Hugging Face Hub.

## Checks

- [x] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: aribornstein <x@x.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
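For orientation, a minimal usage sketch distilled from the docstring examples added in this PR; the model path and generation parameters are placeholders you would replace with your own:

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    # Placeholder path: point this at any local GGUF model file.
    client = LlamaCppChatCompletionClient(model_path="/path/to/model.gguf", n_gpu_layers=-1, n_ctx=4096)
    result = await client.create([UserMessage(content="What is the capital of France?", source="user")])
    print(result.content)
    await client.close()


asyncio.run(main())
```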
This commit is contained in:
parent
a1858efac9
commit
6a3acc4548
@@ -197,7 +197,7 @@ jobs:
       - name: Install Python deps
         run: |
-          uv sync --locked --all-extras
+          uv sync --locked --all-extras --no-extra llama-cpp
         shell: pwsh
         working-directory: ./python
@@ -54,6 +54,7 @@ python/autogen_ext.models.azure
 python/autogen_ext.models.anthropic
 python/autogen_ext.models.semantic_kernel
 python/autogen_ext.models.ollama
+python/autogen_ext.models.llama_cpp
 python/autogen_ext.tools.code_execution
 python/autogen_ext.tools.graphrag
 python/autogen_ext.tools.http
@@ -0,0 +1,9 @@
autogen\_ext.models.llama\_cpp
==============================


.. automodule:: autogen_ext.models.llama_cpp
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
@@ -32,6 +32,11 @@ file-surfer = [
     "autogen-agentchat==0.4.8",
     "markitdown~=0.0.1",
 ]
+
+llama-cpp = [
+    "llama-cpp-python>=0.1.9",
+]
+
 graphrag = ["graphrag>=1.0.1"]
 chromadb = ["chromadb"]
 web-surfer = [
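With this optional dependency group in place, the client is installed through the new extra; the same command appears in the error message of the package `__init__.py` shown in the next hunk:

```bash
pip install "autogen-ext[llama-cpp]"
```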
@@ -0,0 +1,10 @@
try:
    from ._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError as e:
    raise ImportError(
        "Dependencies for Llama Cpp not found. "
        "Please install llama-cpp-python: "
        "pip install autogen-ext[llama-cpp]"
    ) from e

__all__ = ["LlamaCppChatCompletionClient"]
@@ -0,0 +1,426 @@
import logging  # added import
import re
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast

from autogen_core import EVENT_LOGGER_NAME, CancellationToken, FunctionCall, MessageHandlerContext
from autogen_core.logging import LLMCallEvent
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FinishReasons,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
    validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from llama_cpp import (
    ChatCompletionFunctionParameters,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestFunctionMessage,
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestUserMessage,
    ChatCompletionTool,
    ChatCompletionToolFunction,
    Llama,
    llama_chat_format,
)
from typing_extensions import Unpack

logger = logging.getLogger(EVENT_LOGGER_NAME)  # initialize logger


def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
    if stop_reason is None:
        return "unknown"

    # Convert to lower case
    stop_reason = stop_reason.lower()

    KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
        "stop": "stop",
        "length": "length",
        "content_filter": "content_filter",
        "function_calls": "function_calls",
        "end_turn": "stop",
        "tool_calls": "function_calls",
    }

    return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


def normalize_name(name: str) -> str:
    """
    LLMs sometimes call functions while ignoring their own format requirements; this function should be used to replace invalid characters with "_".

    Prefer _assert_valid_name for validating user configuration or input
    """
    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
    """
    Ensure that configured names are valid, raises ValueError if not.

    For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
    """
    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
    if len(name) > 64:
        raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
    return name


def convert_tools(
    tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionTool]:
    result: List[ChatCompletionTool] = []
    for tool in tools:
        if isinstance(tool, Tool):
            tool_schema = tool.schema
        else:
            assert isinstance(tool, dict)
            tool_schema = tool

        result.append(
            ChatCompletionTool(
                type="function",
                function=ChatCompletionToolFunction(
                    name=tool_schema["name"],
                    description=(tool_schema["description"] if "description" in tool_schema else ""),
                    parameters=(
                        cast(ChatCompletionFunctionParameters, tool_schema["parameters"])
                        if "parameters" in tool_schema
                        else {}
                    ),
                ),
            )
        )
    # Check if all tools have valid names.
    for tool_param in result:
        assert_valid_name(tool_param["function"]["name"])
    return result


class LlamaCppParams(TypedDict, total=False):
    # from_pretrained parameters:
    repo_id: Optional[str]
    filename: Optional[str]
    additional_files: Optional[List[Any]]
    local_dir: Optional[str]
    local_dir_use_symlinks: Union[bool, Literal["auto"]]
    cache_dir: Optional[str]
    # __init__ parameters:
    model_path: str
    n_gpu_layers: int
    split_mode: int
    main_gpu: int
    tensor_split: Optional[List[float]]
    rpc_servers: Optional[str]
    vocab_only: bool
    use_mmap: bool
    use_mlock: bool
    kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]]
    seed: int
    n_ctx: int
    n_batch: int
    n_ubatch: int
    n_threads: Optional[int]
    n_threads_batch: Optional[int]
    rope_scaling_type: Optional[int]
    pooling_type: int
    rope_freq_base: float
    rope_freq_scale: float
    yarn_ext_factor: float
    yarn_attn_factor: float
    yarn_beta_fast: float
    yarn_beta_slow: float
    yarn_orig_ctx: int
    logits_all: bool
    embedding: bool
    offload_kqv: bool
    flash_attn: bool
    no_perf: bool
    last_n_tokens_size: int
    lora_base: Optional[str]
    lora_scale: float
    lora_path: Optional[str]
    numa: Union[bool, int]
    chat_format: Optional[str]
    chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler]
    draft_model: Optional[Any]  # LlamaDraftModel not exposed by llama_cpp
    tokenizer: Optional[Any]  # BaseLlamaTokenizer not exposed by llama_cpp
    type_k: Optional[int]
    type_v: Optional[int]
    spm_infill: bool
    verbose: bool


class LlamaCppChatCompletionClient(ChatCompletionClient):
    """Chat completion client for LlamaCpp models.

    To use this client, you must install the `llama-cpp` extra:

    .. code-block:: bash

        pip install "autogen-ext[llama-cpp]"

    This client allows you to interact with LlamaCpp models, either by specifying a local model path or by downloading a model from Hugging Face Hub.

    Args:
        model_path (optional, str): The path to the LlamaCpp model file. Required if repo_id and filename are not provided.
        repo_id (optional, str): The Hugging Face Hub repository ID. Required if model_path is not provided.
        filename (optional, str): The filename of the model within the Hugging Face Hub repository. Required if model_path is not provided.
        n_gpu_layers (optional, int): The number of layers to put on the GPU.
        n_ctx (optional, int): The context size.
        n_batch (optional, int): The batch size.
        verbose (optional, bool): Whether to print verbose output.
        model_info (optional, ModelInfo): The capabilities of the model. Defaults to a ModelInfo instance with function_calling set to True.
        **kwargs: Additional parameters to pass to the Llama class.

    Examples:

        The following code snippet shows how to use the client with a local model file:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(model_path="/path/to/your/model.gguf")
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())

        The following code snippet shows how to use the client with a model from Hugging Face Hub:

        .. code-block:: python

            import asyncio

            from autogen_core.models import UserMessage
            from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


            async def main():
                llama_client = LlamaCppChatCompletionClient(
                    repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
                )
                result = await llama_client.create([UserMessage(content="What is the capital of France?", source="user")])
                print(result)


            asyncio.run(main())
    """

    def __init__(
        self,
        model_info: Optional[ModelInfo] = None,
        **kwargs: Unpack[LlamaCppParams],
    ) -> None:
        """
        Initialize the LlamaCpp client.
        """

        if model_info:
            validate_model_info(model_info)

        if "repo_id" in kwargs and "filename" in kwargs and kwargs["repo_id"] and kwargs["filename"]:
            repo_id: str = cast(str, kwargs.pop("repo_id"))
            filename: str = cast(str, kwargs.pop("filename"))
            pretrained = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)  # type: ignore
            assert isinstance(pretrained, Llama)
            self.llm = pretrained

        elif "model_path" in kwargs:
            self.llm = Llama(**kwargs)  # pyright: ignore[reportUnknownMemberType]
        else:
            raise ValueError("Please provide model_path if ... or provide repo_id and filename if ....")
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
        converted_messages: list[
            ChatCompletionRequestSystemMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestAssistantMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestToolMessage
            | ChatCompletionRequestFunctionMessage
        ] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "assistant", "content": msg.content})
            elif (
                isinstance(msg, SystemMessage) or isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage)
            ) and isinstance(msg.content, list):
                raise ValueError("Multi-part messages such as those containing images are currently not supported.")
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        if self.model_info["function_calling"]:
            response = self.llm.create_chat_completion(
                messages=converted_messages, tools=convert_tools(tools), stream=False
            )
        else:
            response = self.llm.create_chat_completion(messages=converted_messages, stream=False)

        if not isinstance(response, dict):
            raise ValueError("Unexpected response type from LlamaCpp model.")

        self._total_usage["prompt_tokens"] += response["usage"]["prompt_tokens"]
        self._total_usage["completion_tokens"] += response["usage"]["completion_tokens"]

        # Parse the response
        response_tool_calls: ChatCompletionTool | None = None
        response_text: str | None = None
        if "choices" in response and len(response["choices"]) > 0:
            if "message" in response["choices"][0]:
                response_text = response["choices"][0]["message"]["content"]
            if "tool_calls" in response["choices"][0]:
                response_tool_calls = response["choices"][0]["tool_calls"]  # type: ignore

        content: List[FunctionCall] | str = ""
        thought: str | None = None
        if response_tool_calls:
            content = []
            for tool_call in response_tool_calls:
                if not isinstance(tool_call, dict):
                    raise ValueError("Unexpected tool call type from LlamaCpp model.")
                content.append(
                    FunctionCall(
                        id=tool_call["id"],
                        arguments=tool_call["function"]["arguments"],
                        name=normalize_name(tool_call["function"]["name"]),
                    )
                )
            if response_text and len(response_text) > 0:
                thought = response_text
        else:
            if response_text:
                content = response_text

        # Detect tool usage in the response
        if not response_tool_calls and not response_text:
            logger.debug("DEBUG: No response text found. Returning empty response.")
            return CreateResult(
                content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False
            )

        # Create a CreateResult object
        if "finish_reason" in response["choices"][0]:
            finish_reason = response["choices"][0]["finish_reason"]
        else:
            finish_reason = "unknown"
        if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
            finish_reason = "unknown"
        create_result = CreateResult(
            content=content,
            thought=thought,
            usage=cast(RequestUsage, response["usage"]),
            finish_reason=normalize_stop_reason(finish_reason),  # type: ignore
            cached=False,
        )

        # If we are running in the context of a handler we can get the agent_id
        try:
            agent_id = MessageHandlerContext.agent_id()
        except RuntimeError:
            agent_id = None

        logger.info(
            LLMCallEvent(
                messages=cast(List[Dict[str, Any]], converted_messages),
                response=create_result.model_dump(),
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["completion_tokens"],
                agent_id=agent_id,
            )
        )
        return create_result

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        # None means do not override the default
        # A value means to override the client default - often specified in the constructor
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        raise NotImplementedError("Stream not yet implemented for LlamaCppChatCompletionClient")
        yield ""

    # Implement abstract methods
    def actual_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )

    @property
    def capabilities(self) -> ModelInfo:
        return self.model_info

    def count_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        total = 0
        for msg in messages:
            # Use the Llama model's tokenizer to encode the content
            tokens = self.llm.tokenize(str(msg.content).encode("utf-8"))
            total += len(tokens)
        return total

    @property
    def model_info(self) -> ModelInfo:
        return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)

    def remaining_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        used_tokens = self.count_tokens(messages)
        return max(self.llm.n_ctx() - used_tokens, 0)

    def total_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )

    async def close(self) -> None:
        """
        Close the LlamaCpp client.
        """
        self.llm.close()
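As a usage sketch for the function-calling path above: tools passed to `create` are converted by `convert_tools` into llama-cpp `ChatCompletionTool` entries, and any tool calls come back as `FunctionCall` objects in `result.content`. The model path below is a placeholder, and whether a tool call is actually emitted depends on the model:

```python
import asyncio

from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


def add(num1: int, num2: int) -> int:
    """Add two numbers together"""
    return num1 + num2


async def main() -> None:
    client = LlamaCppChatCompletionClient(model_path="/path/to/model.gguf")  # placeholder path
    add_tool = FunctionTool(add, description="Add two numbers together.")
    result = await client.create(
        [UserMessage(content="add 3 and 4", source="user")],
        tools=[add_tool],
    )
    # If the model emitted tool calls, result.content is a list of FunctionCall objects;
    # otherwise it is the plain text reply.
    print(result.content)
    await client.close()


asyncio.run(main())
```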
@@ -3,8 +3,8 @@
 import asyncio
 import os
-import shutil
 import platform
+import shutil
 import sys
 import tempfile
 import venv
@@ -18,7 +18,6 @@ from autogen_core import CancellationToken
 from autogen_core.code_executor import CodeBlock
 from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
-
 
 HAS_POWERSHELL: bool = platform.system() == "Windows" and (
     shutil.which("powershell") is not None or shutil.which("pwsh") is not None
 )
@@ -0,0 +1,205 @@
import contextlib
import sys
from typing import TYPE_CHECKING, Any, ContextManager, Generator, List, Sequence, Union

import pytest
import torch

# from autogen_agentchat.agents import AssistantAgent
# from autogen_agentchat.messages import TextMessage
# from autogen_core import CancellationToken
from autogen_core.models import RequestUsage, SystemMessage, UserMessage

# from autogen_core.tools import FunctionTool
try:
    from llama_cpp import ChatCompletionMessageToolCalls

    if TYPE_CHECKING:
        from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError:
    # If llama_cpp is not installed, we can't run the tests.
    pytest.skip("Skipping LlamaCppChatCompletionClient tests: llama-cpp-python not installed", allow_module_level=True)


# Fake Llama class to simulate responses
class FakeLlama:
    def __init__(
        self,
        model_path: str,
        **_: Any,
    ) -> None:
        self.model_path = model_path
        self.n_ctx = lambda: 1024

    # Added tokenize method for testing purposes.
    def tokenize(self, b: bytes) -> list[int]:
        return list(b)

    def create_chat_completion(
        self, messages: Any, tools: List[ChatCompletionMessageToolCalls] | None, stream: bool = False
    ) -> dict[str, Any]:
        # Return fake non-streaming response.
        return {
            "usage": {"prompt_tokens": 1, "completion_tokens": 2},
            "choices": [{"message": {"content": "Fake response"}}],
        }

    def __call__(self, prompt: str, stream: bool = True) -> Generator[dict[str, Any], None, None]:
        # Yield fake streaming tokens.
        yield {"choices": [{"text": "Hello "}]}
        yield {"choices": [{"text": "World"}]}


@pytest.fixture
@contextlib.contextmanager
def get_completion_client(
    monkeypatch: pytest.MonkeyPatch,
) -> "Generator[type[LlamaCppChatCompletionClient], None, None]":
    with monkeypatch.context() as m:
        m.setattr("llama_cpp.Llama", FakeLlama)
        from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

        yield LlamaCppChatCompletionClient
    sys.modules.pop("autogen_ext.models.llama_cpp._llama_cpp_completion_client", None)
    sys.modules.pop("llama_cpp", None)


@pytest.mark.asyncio
async def test_llama_cpp_create(get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]") -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        messages: Sequence[Union[SystemMessage, UserMessage]] = [
            SystemMessage(content="Test system"),
            UserMessage(content="Test user", source="user"),
        ]
        result = await client.create(messages=messages)
        assert result.content == "Fake response"
        usage: RequestUsage = result.usage
        assert usage.prompt_tokens == 1
        assert usage.completion_tokens == 2
        assert result.finish_reason in ("stop", "unknown")


# Commented out because create_stream raises NotImplementedError; left in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_create_stream(
#     get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
# ) -> None:
#     with get_completion_client as Client:
#         client = Client(filename="dummy")
#         messages: Sequence[Union[SystemMessage, UserMessage]] = [
#             SystemMessage(content="Test system"),
#             UserMessage(content="Test user", source="user"),
#         ]
#         collected = ""
#         async for token in client.create_stream(messages=messages):
#             collected += token
#         assert collected == "Hello World"


@pytest.mark.asyncio
async def test_create_invalid_message(
    get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
) -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        # Pass an unsupported message type (integer) to trigger ValueError.
        with pytest.raises(ValueError, match="Unsupported message type"):
            await client.create(messages=[123])  # type: ignore


@pytest.mark.asyncio
async def test_count_and_remaining_tokens(
    get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]", monkeypatch: pytest.MonkeyPatch
) -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        msg = SystemMessage(content="Test")
        # count_tokens should count the bytes
        token_count = client.count_tokens([msg])
        # Since "Test" encoded is 4 bytes, expect 4 tokens.
        assert token_count >= 4
        remaining = client.remaining_tokens([msg])
        # remaining should be (1024 - token_count); ensure non-negative.
        assert remaining == max(1024 - token_count, 0)


@pytest.mark.asyncio
async def test_llama_cpp_integration_non_streaming() -> None:
    if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
        pytest.skip("Skipping LlamaCpp integration tests: GPU not available")

    from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

    client = LlamaCppChatCompletionClient(
        repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
    )
    messages: Sequence[Union[SystemMessage, UserMessage]] = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Hello, how are you?", source="user"),
    ]
    result = await client.create(messages=messages)
    assert isinstance(result.content, str) and len(result.content.strip()) > 0


# Commented out because create_stream raises NotImplementedError; left in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_integration_streaming() -> None:
#     if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
#         pytest.skip("Skipping LlamaCpp integration tests: GPU not available")
#
#     from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
#
#     client = LlamaCppChatCompletionClient(
#         repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
#     )
#     messages: Sequence[Union[SystemMessage, UserMessage]] = [
#         SystemMessage(content="You are a helpful assistant."),
#         UserMessage(content="Please stream your response.", source="user"),
#     ]
#     collected = ""
#     async for token in client.create_stream(messages=messages):
#         collected += token
#     assert isinstance(collected, str) and len(collected.strip()) > 0

# Commented out tool use as this functionality is not yet implemented for Phi-4.
# Define tools (functions) for the AssistantAgent
# def add(num1: int, num2: int) -> int:
#     """Add two numbers together"""
#     return num1 + num2


# @pytest.mark.asyncio
# async def test_llama_cpp_integration_tool_use() -> None:
#     if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
#         pytest.skip("Skipping LlamaCpp integration tests: GPU not available")
#
#     from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
#
#     model_client = LlamaCppChatCompletionClient(
#         repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
#     )
#
#     # Initialize the AssistantAgent
#     assistant = AssistantAgent(
#         name="assistant",
#         system_message=("You can add two numbers together using the `add` function. "),
#         model_client=model_client,
#         tools=[
#             FunctionTool(
#                 add,
#                 description="Add two numbers together. The first argument is num1 and second is num2. The return value is num1 + num2",
#             )
#         ],
#         reflect_on_tool_use=True,  # Reflect on tool results
#     )
#
#     # Test the tool
#     response = await assistant.on_messages(
#         [
#             TextMessage(content="add 3 and 4", source="user"),
#         ],
#         CancellationToken(),
#     )
#
#     assert "7" in response.chat_message.content
@@ -628,6 +628,9 @@ jupyter-executor = [
 langchain = [
     { name = "langchain-core" },
 ]
+llama-cpp = [
+    { name = "llama-cpp-python" },
+]
 magentic-one = [
     { name = "autogen-agentchat" },
     { name = "markitdown" },
@@ -735,6 +738,7 @@ requires-dist = [
     { name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" },
     { name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.2" },
     { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
+    { name = "llama-cpp-python", marker = "extra == 'llama-cpp'", specifier = ">=0.1.9" },
     { name = "markitdown", marker = "extra == 'file-surfer'", specifier = "~=0.0.1" },
     { name = "markitdown", marker = "extra == 'magentic-one'", specifier = "~=0.0.1" },
     { name = "markitdown", marker = "extra == 'web-surfer'", specifier = "~=0.0.1" },
@@ -3501,6 +3505,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/43/b1/9355547c3b9043ba2821e7797f322c753dfa4d2a3da7bb05690fce689eaa/llama_cloud-0.1.11-py3-none-any.whl", hash = "sha256:b703765d03783a5a0fc57a52adc9892f8b91b0c19bbecb85a54ad4e813342951", size = 250609 },
 ]
+
+[[package]]
+name = "llama-cpp-python"
+version = "0.3.7"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "diskcache" },
+    { name = "jinja2" },
+    { name = "numpy" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/a6/38/7a47b1fb1d83eaddd86ca8ddaf20f141cbc019faf7b425283d8e5ef710e5/llama_cpp_python-0.3.7.tar.gz", hash = "sha256:0566a0dcc0f38005c4093309a87f67c2452449522e3e17e15cd735a62957894c", size = 66715891 }
+
 [[package]]
 name = "llama-index"
 version = "0.12.14"