Add model_context property to AssistantAgent (#6072)

AssistantAgent initiation allows one to pass in a model_context, but
there isn't a "public: way to get the existing model_context created by
default.
This commit is contained in:
jspv 2025-03-22 23:21:29 -04:00 committed by GitHub
parent e28738ac6f
commit fc2c9978fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 0 deletions

View File

@ -702,6 +702,13 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
message_types.append(ToolCallSummaryMessage) message_types.append(ToolCallSummaryMessage)
return tuple(message_types) return tuple(message_types)
@property
def model_context(self) -> ChatCompletionContext:
"""
The model context in use by the agent.
"""
return self._model_context
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token): async for message in self.on_messages_stream(messages, cancellation_token):
if isinstance(message, Response): if isinstance(message, Response):

View File

@ -615,6 +615,8 @@ async def test_model_context(monkeypatch: pytest.MonkeyPatch) -> None:
] ]
await agent.run(task=messages) await agent.run(task=messages)
# Check that the model_context property returns the correct internal context
assert agent.model_context == model_context
# Check if the mock client is called with only the last two messages. # Check if the mock client is called with only the last two messages.
assert len(model_client.create_calls) == 1 assert len(model_client.create_calls) == 1
# 2 message from the context + 1 system message # 2 message from the context + 1 system message