mirror of https://github.com/0xplaygrounds/rig
chore: comment the new streaming code
This commit is contained in:
parent
ce37d44b15
commit
6dcabc1e0e
|
@ -30,7 +30,8 @@ pub enum RawStreamingChoice<R: Clone> {
|
|||
/// A tool call response chunk
|
||||
ToolCall(String, String, serde_json::Value),
|
||||
|
||||
/// The final response object
|
||||
/// The final response object, must be yielded if you want the
|
||||
/// `response` field to be populated on the `StreamingCompletionResponse`
|
||||
FinalResponse(R),
|
||||
}
|
||||
|
||||
|
@ -42,11 +43,18 @@ pub type StreamingResult<R> =
|
|||
pub type StreamingResult<R> =
|
||||
Pin<Box<dyn Stream<Item = Result<RawStreamingChoice<R>, CompletionError>>>>;
|
||||
|
||||
/// The response from a streaming completion request;
|
||||
/// message and response are populated at the end of the
|
||||
/// `inner` stream.
|
||||
pub struct StreamingCompletionResponse<R: Clone + Unpin> {
|
||||
inner: StreamingResult<R>,
|
||||
text: String,
|
||||
tool_calls: Vec<(String, String, serde_json::Value)>,
|
||||
/// The final aggregated message from the stream
|
||||
/// contains all text and tool calls generated
|
||||
pub message: Message,
|
||||
/// The final response from the stream, may be `None`
|
||||
/// if the provider didn't yield it during the stream
|
||||
pub response: Option<R>,
|
||||
}
|
||||
|
||||
|
@ -71,12 +79,15 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
|||
match stream.inner.as_mut().poll_next(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(None) => {
|
||||
// This is run at the end of the inner stream to collect all tokens into
|
||||
// a single unified `Message`.
|
||||
let mut content = vec![];
|
||||
|
||||
stream.tool_calls.iter().for_each(|(n, d, a)| {
|
||||
content.push(AssistantContent::tool_call(n, d, a.clone()));
|
||||
});
|
||||
|
||||
// This is required to ensure there's always at least one item in the content
|
||||
if content.is_empty() || !stream.text.is_empty() {
|
||||
content.insert(0, AssistantContent::text(stream.text.clone()));
|
||||
}
|
||||
|
@ -91,16 +102,21 @@ impl<R: Clone + Unpin> Stream for StreamingCompletionResponse<R> {
|
|||
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||
Poll::Ready(Some(Ok(choice))) => match choice {
|
||||
RawStreamingChoice::Message(text) => {
|
||||
// Forward the streaming tokens to the outer stream
|
||||
// and concat the text together
|
||||
stream.text = format!("{}{}", stream.text, text.clone());
|
||||
Poll::Ready(Some(Ok(AssistantContent::text(text))))
|
||||
}
|
||||
RawStreamingChoice::ToolCall(id, name, args) => {
|
||||
// Keep track of each tool call to aggregate the final message later
|
||||
// and pass it to the outer stream
|
||||
stream
|
||||
.tool_calls
|
||||
.push((id.clone(), name.clone(), args.clone()));
|
||||
Poll::Ready(Some(Ok(AssistantContent::tool_call(id, name, args))))
|
||||
}
|
||||
RawStreamingChoice::FinalResponse(response) => {
|
||||
// Set the final response field and return the next item in the stream
|
||||
stream.response = Some(response);
|
||||
|
||||
stream.poll_next_unpin(cx)
|
||||
|
|
Loading…
Reference in New Issue