mirror of https://github.com/microsoft/autogen.git
Fix streaming + tool bug in Ollama (#6193)
Fix a bug that caused tool calls to be truncated in OllamaChatCompletionClient when streaming is on.
This commit is contained in:
parent
5508cc7a43
commit
d4ac2ca6de
|
@ -646,14 +646,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
|||
content: Union[str, List[FunctionCall]]
|
||||
thought: Optional[str] = None
|
||||
if result.message.tool_calls is not None:
|
||||
# TODO: What are possible values for done_reason?
|
||||
if result.done_reason != "tool_calls":
|
||||
warnings.warn(
|
||||
f"Finish reason mismatch: {result.done_reason} != tool_calls "
|
||||
"when tool_calls are present. Finish reason may not be accurate. "
|
||||
"This may be due to the API used that is not returning the correct finish reason.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if result.message.content is not None and result.message.content != "":
|
||||
thought = result.message.content
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
|
@ -760,9 +752,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
|||
content_chunks.append(chunk.message.content)
|
||||
if len(chunk.message.content) > 0:
|
||||
yield chunk.message.content
|
||||
continue
|
||||
|
||||
# Otherwise, get tool calls
|
||||
# Get tool calls
|
||||
if chunk.message.tool_calls is not None:
|
||||
full_tool_calls.extend(
|
||||
[
|
||||
|
@ -796,9 +787,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
|
|||
else:
|
||||
prompt_tokens = 0
|
||||
|
||||
if stop_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
|
||||
content: Union[str, List[FunctionCall]]
|
||||
thought: Optional[str] = None
|
||||
|
||||
|
|
|
@ -206,6 +206,77 @@ async def test_create_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
assert create_result.usage.completion_tokens == 12
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_stream_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def add(x: int, y: int) -> str:
|
||||
return str(x + y)
|
||||
|
||||
add_tool = FunctionTool(add, description="Add two numbers")
|
||||
model = "llama3.2"
|
||||
content_raw = "Hello world! This is a test response. Test response."
|
||||
|
||||
async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
|
||||
assert "stream" in kwargs
|
||||
assert kwargs["stream"] is True
|
||||
|
||||
async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
|
||||
chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
|
||||
# Simulate streaming by yielding chunks of the response
|
||||
for chunk in chunks[:-1]:
|
||||
yield ChatResponse(
|
||||
model=model,
|
||||
done=False,
|
||||
message=Message(
|
||||
role="assistant",
|
||||
content=chunk,
|
||||
),
|
||||
)
|
||||
yield ChatResponse(
|
||||
model=model,
|
||||
done=True,
|
||||
done_reason="stop",
|
||||
message=Message(
|
||||
content=chunks[-1],
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
Message.ToolCall(
|
||||
function=Message.ToolCall.Function(
|
||||
name=add_tool.name,
|
||||
arguments={"x": 2, "y": 2},
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
prompt_eval_count=10,
|
||||
eval_count=12,
|
||||
)
|
||||
|
||||
return _mock_stream()
|
||||
|
||||
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
|
||||
client = OllamaChatCompletionClient(model=model)
|
||||
stream = client.create_stream(
|
||||
messages=[
|
||||
UserMessage(content="hi", source="user"),
|
||||
],
|
||||
tools=[add_tool],
|
||||
)
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
assert len(chunks) > 0
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
assert isinstance(chunks[-1].content, list)
|
||||
assert len(chunks[-1].content) > 0
|
||||
assert isinstance(chunks[-1].content[0], FunctionCall)
|
||||
assert chunks[-1].content[0].name == add_tool.name
|
||||
assert chunks[-1].content[0].arguments == json.dumps({"x": 2, "y": 2})
|
||||
assert chunks[-1].finish_reason == "stop"
|
||||
assert chunks[-1].usage is not None
|
||||
assert chunks[-1].usage.prompt_tokens == 10
|
||||
assert chunks[-1].usage.completion_tokens == 12
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class ResponseType(BaseModel):
|
||||
|
@ -541,7 +612,6 @@ async def test_ollama_create_structured_output_with_tools(
|
|||
assert ResponseType.model_validate_json(create_result.thought)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO: Fix streaming with tools")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
|
||||
async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
|
||||
|
@ -569,7 +639,7 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
|
|||
assert len(create_result.content) > 0
|
||||
assert isinstance(create_result.content[0], FunctionCall)
|
||||
assert create_result.content[0].name == add_tool.name
|
||||
assert create_result.finish_reason == "function_calls"
|
||||
assert create_result.finish_reason == "stop"
|
||||
|
||||
execution_result = FunctionExecutionResult(
|
||||
content="4",
|
||||
|
|
Loading…
Reference in New Issue