Fix streaming + tool bug in Ollama (#6193)

Fix a bug that caused tool calls to be truncated in
OllamaChatCompletionClient when streaming is on.
Eric Zhu 2025-04-03 14:56:01 -07:00 committed by GitHub
parent 5508cc7a43
commit d4ac2ca6de
2 changed files with 73 additions and 15 deletions
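
For context, the bug surfaces on the OllamaChatCompletionClient.create_stream path when tools are supplied: the final CreateResult of the stream should carry the accumulated FunctionCall list, but tool calls could be dropped. A minimal usage sketch follows (assuming a locally running Ollama server with the llama3.2 model pulled; the add helper and main wrapper are illustrative, not part of this change):

import asyncio

from autogen_core import FunctionCall
from autogen_core.models import CreateResult, UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient


def add(x: int, y: int) -> str:
    return str(x + y)


async def main() -> None:
    client = OllamaChatCompletionClient(model="llama3.2")
    add_tool = FunctionTool(add, description="Add two numbers")
    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
    )
    result: str | CreateResult | None = None
    async for chunk in stream:
        result = chunk  # the last item yielded is the final CreateResult
    assert isinstance(result, CreateResult)
    # Before this fix, streamed responses could drop the tool calls here;
    # after it, result.content holds the accumulated FunctionCall list.
    if isinstance(result.content, list):
        print([c.name for c in result.content if isinstance(c, FunctionCall)])


asyncio.run(main())
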


@@ -646,14 +646,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None
         if result.message.tool_calls is not None:
-            # TODO: What are possible values for done_reason?
-            if result.done_reason != "tool_calls":
-                warnings.warn(
-                    f"Finish reason mismatch: {result.done_reason} != tool_calls "
-                    "when tool_calls are present. Finish reason may not be accurate. "
-                    "This may be due to the API used that is not returning the correct finish reason.",
-                    stacklevel=2,
-                )
             if result.message.content is not None and result.message.content != "":
                 thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
@@ -760,9 +752,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 content_chunks.append(chunk.message.content)
                 if len(chunk.message.content) > 0:
                     yield chunk.message.content
-                continue

-            # Otherwise, get tool calls
+            # Get tool calls
             if chunk.message.tool_calls is not None:
                 full_tool_calls.extend(
                     [
@@ -796,9 +787,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         else:
             prompt_tokens = 0

-        if stop_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
-
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None

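The client-side change above removes an early continue in the streaming loop: once a chunk's text had been yielded, the handler skipped the rest of that chunk, so tool calls delivered on the same chunk were lost. Below is a simplified sketch of the accumulation pattern the loop now follows (the collect_stream helper is hypothetical, and the synthesized call ids are an assumption for illustration, not taken from this diff):

import json
import uuid
from typing import AsyncIterator, List, Tuple

from autogen_core import FunctionCall
from ollama import ChatResponse


async def collect_stream(stream: AsyncIterator[ChatResponse]) -> Tuple[str, List[FunctionCall]]:
    """Accumulate streamed text and tool calls from an Ollama chat stream."""
    content_chunks: List[str] = []
    full_tool_calls: List[FunctionCall] = []
    async for chunk in stream:
        if chunk.message.content is not None:
            content_chunks.append(chunk.message.content)
        # No early `continue` here: a chunk can carry both text and tool calls,
        # so tool calls are collected on the same pass instead of being skipped.
        if chunk.message.tool_calls is not None:
            full_tool_calls.extend(
                FunctionCall(
                    id=str(uuid.uuid4()),  # assumption: ids are synthesized; Ollama does not return them
                    name=call.function.name,
                    arguments=json.dumps(dict(call.function.arguments)),
                )
                for call in chunk.message.tool_calls
            )
    return "".join(content_chunks), full_tool_calls
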

@@ -206,6 +206,77 @@ async def test_create_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.usage.completion_tokens == 12


+@pytest.mark.asyncio
+async def test_create_stream_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    content_raw = "Hello world! This is a test response. Test response."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
+            # Simulate streaming by yielding chunks of the response
+            for chunk in chunks[:-1]:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="stop",
+                message=Message(
+                    content=chunks[-1],
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+        tools=[add_tool],
+    )
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, list)
+    assert len(chunks[-1].content) > 0
+    assert isinstance(chunks[-1].content[0], FunctionCall)
+    assert chunks[-1].content[0].name == add_tool.name
+    assert chunks[-1].content[0].arguments == json.dumps({"x": 2, "y": 2})
+    assert chunks[-1].finish_reason == "stop"
+    assert chunks[-1].usage is not None
+    assert chunks[-1].usage.prompt_tokens == 10
+    assert chunks[-1].usage.completion_tokens == 12
+
+
 @pytest.mark.asyncio
 async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     class ResponseType(BaseModel):
@@ -541,7 +612,6 @@ async def test_ollama_create_structured_output_with_tools(
     assert ResponseType.model_validate_json(create_result.thought)


-@pytest.mark.skip("TODO: Fix streaming with tools")
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
 async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
@@ -569,7 +639,7 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
     assert len(create_result.content) > 0
     assert isinstance(create_result.content[0], FunctionCall)
     assert create_result.content[0].name == add_tool.name
-    assert create_result.finish_reason == "function_calls"
+    assert create_result.finish_reason == "stop"
     execution_result = FunctionExecutionResult(
         content="4",