mirror of https://github.com/microsoft/autogen.git
Fix streaming + tool bug in Ollama (#6193)
Fix a bug that caused tool calls to be truncated in OllamaChatCompletionClient when streaming is on.
This commit is contained in:
parent 5508cc7a43
commit d4ac2ca6de
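For context, the scenario this fix covers is streaming a completion while tools are registered and then reading the tool calls off the final CreateResult. Below is a minimal usage sketch mirroring the new test_create_stream_tools test added in this commit; the import paths, the llama3.2 model name, and the prompt are illustrative assumptions, not part of the commit.

import asyncio

from autogen_core.models import CreateResult, UserMessage
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient


def add(x: int, y: int) -> str:
    return str(x + y)


async def main() -> None:
    # Register a tool and stream a completion from a local Ollama model.
    add_tool = FunctionTool(add, description="Add two numbers")
    client = OllamaChatCompletionClient(model="llama3.2")  # assumed model name

    stream = client.create_stream(
        messages=[UserMessage(content="What is 2 + 2?", source="user")],
        tools=[add_tool],
    )

    final: CreateResult | None = None
    async for chunk in stream:
        if isinstance(chunk, CreateResult):
            final = chunk  # the last item yielded is the aggregated result
        else:
            print(chunk, end="")  # intermediate text chunks

    # Before this fix, tool calls carried by the final streamed chunk could be truncated.
    if final is not None:
        print(final.finish_reason, final.content)


asyncio.run(main())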
@@ -646,14 +646,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None
         if result.message.tool_calls is not None:
-            # TODO: What are possible values for done_reason?
-            if result.done_reason != "tool_calls":
-                warnings.warn(
-                    f"Finish reason mismatch: {result.done_reason} != tool_calls "
-                    "when tool_calls are present. Finish reason may not be accurate. "
-                    "This may be due to the API used that is not returning the correct finish reason.",
-                    stacklevel=2,
-                )
             if result.message.content is not None and result.message.content != "":
                 thought = result.message.content
             # NOTE: If OAI response type changes, this will need to be updated
@@ -760,9 +752,8 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 content_chunks.append(chunk.message.content)
                 if len(chunk.message.content) > 0:
                     yield chunk.message.content
-                continue
 
-            # Otherwise, get tool calls
+            # Get tool calls
             if chunk.message.tool_calls is not None:
                 full_tool_calls.extend(
                     [
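The hunk above is the core of the fix: a streamed chunk can carry both text content and tool calls, and the old early continue skipped the tool-call branch for exactly those chunks, which is how calls ended up truncated. A rough sketch of the corrected accumulation loop, with simplified names and types that are not the commit's exact code:

from typing import Any, AsyncIterator, List, Tuple


async def collect(stream: AsyncIterator[Any]) -> Tuple[List[str], List[Any]]:
    # Sketch only: every chunk is checked for BOTH text content and tool calls;
    # the removed `continue` used to skip the tool-call check whenever a chunk
    # also carried text.
    content_chunks: List[str] = []
    full_tool_calls: List[Any] = []
    async for chunk in stream:
        if chunk.message.content is not None:
            content_chunks.append(chunk.message.content)
        if chunk.message.tool_calls is not None:
            full_tool_calls.extend(chunk.message.tool_calls)
    return content_chunks, full_tool_calls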
|
@@ -796,9 +787,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         else:
             prompt_tokens = 0
 
-        if stop_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
-
         content: Union[str, List[FunctionCall]]
         thought: Optional[str] = None
 
@@ -206,6 +206,77 @@ async def test_create_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     assert create_result.usage.completion_tokens == 12
 
 
+@pytest.mark.asyncio
+async def test_create_stream_tools(monkeypatch: pytest.MonkeyPatch) -> None:
+    def add(x: int, y: int) -> str:
+        return str(x + y)
+
+    add_tool = FunctionTool(add, description="Add two numbers")
+    model = "llama3.2"
+    content_raw = "Hello world! This is a test response. Test response."
+
+    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
+        assert "stream" in kwargs
+        assert kwargs["stream"] is True
+
+        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
+            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
+            # Simulate streaming by yielding chunks of the response
+            for chunk in chunks[:-1]:
+                yield ChatResponse(
+                    model=model,
+                    done=False,
+                    message=Message(
+                        role="assistant",
+                        content=chunk,
+                    ),
+                )
+            yield ChatResponse(
+                model=model,
+                done=True,
+                done_reason="stop",
+                message=Message(
+                    content=chunks[-1],
+                    role="assistant",
+                    tool_calls=[
+                        Message.ToolCall(
+                            function=Message.ToolCall.Function(
+                                name=add_tool.name,
+                                arguments={"x": 2, "y": 2},
+                            ),
+                        ),
+                    ],
+                ),
+                prompt_eval_count=10,
+                eval_count=12,
+            )
+
+        return _mock_stream()
+
+    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
+    client = OllamaChatCompletionClient(model=model)
+    stream = client.create_stream(
+        messages=[
+            UserMessage(content="hi", source="user"),
+        ],
+        tools=[add_tool],
+    )
+    chunks: List[str | CreateResult] = []
+    async for chunk in stream:
+        chunks.append(chunk)
+    assert len(chunks) > 0
+    assert isinstance(chunks[-1], CreateResult)
+    assert isinstance(chunks[-1].content, list)
+    assert len(chunks[-1].content) > 0
+    assert isinstance(chunks[-1].content[0], FunctionCall)
+    assert chunks[-1].content[0].name == add_tool.name
+    assert chunks[-1].content[0].arguments == json.dumps({"x": 2, "y": 2})
+    assert chunks[-1].finish_reason == "stop"
+    assert chunks[-1].usage is not None
+    assert chunks[-1].usage.prompt_tokens == 10
+    assert chunks[-1].usage.completion_tokens == 12
+
+
 @pytest.mark.asyncio
 async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
     class ResponseType(BaseModel):
@@ -541,7 +612,6 @@ async def test_ollama_create_structured_output_with_tools(
     assert ResponseType.model_validate_json(create_result.thought)
 
 
-@pytest.mark.skip("TODO: Fix streaming with tools")
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
 async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
@@ -569,7 +639,7 @@ async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatC
     assert len(create_result.content) > 0
     assert isinstance(create_result.content[0], FunctionCall)
     assert create_result.content[0].name == add_tool.name
-    assert create_result.finish_reason == "function_calls"
+    assert create_result.finish_reason == "stop"
 
     execution_result = FunctionExecutionResult(
         content="4",