diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index d097bc8..c1e6938 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -1,33 +1,29 @@ use crate::completion::{CompletionError, CompletionRequest}; -use crate::message::{ToolCall, ToolFunction}; -use crate::providers::cohere::completion::{AssistantContent, BilledUnits, Message, Usage}; +use crate::providers::cohere::completion::Usage; use crate::providers::cohere::CompletionModel; use crate::streaming::{RawStreamingChoice, StreamingCompletionModel}; use crate::{json_utils, streaming}; use async_stream::stream; use futures::StreamExt; use serde::Deserialize; -use serde_json::{json, Value}; -use std::collections::HashMap; -use std::future::Future; +use serde_json::json; #[derive(Debug, Deserialize)] #[serde(rename_all = "kebab-case", tag = "type")] -pub enum StreamingEvent { +enum StreamingEvent { MessageStart, ContentStart, ContentDelta { delta: Option }, ContentEnd, - ToolPlan { delta: Option }, + ToolPlan, ToolCallStart { delta: Option }, ToolCallDelta { delta: Option }, - ToolCallEnd { delta: Option }, + ToolCallEnd, MessageEnd { delta: Option }, } #[derive(Debug, Deserialize)] struct MessageContentDelta { - r#type: Option, text: Option, } @@ -40,14 +36,12 @@ struct MessageToolFunctionDelta { #[derive(Debug, Deserialize)] struct MessageToolCallDelta { id: Option, - r#type: Option, function: Option, } #[derive(Debug, Deserialize)] struct MessageDelta { content: Option, - tool_plan: Option, tool_calls: Option, } @@ -58,7 +52,6 @@ struct Delta { #[derive(Debug, Deserialize)] struct MessageEndDelta { - finish_reason: Option, usage: Option, } @@ -119,9 +112,9 @@ impl StreamingCompletionModel for CompletionModel { let Some(line) = line.strip_prefix("data: ") else { continue; }; - + let event = { - let result = serde_json::from_str::(&line); + let result = serde_json::from_str::(line); let Ok(event) = result else { continue; @@ -129,13 +122,13 @@ impl StreamingCompletionModel for CompletionModel { event }; - + match event { StreamingEvent::ContentDelta { delta: Some(delta) } => { let Some(message) = &delta.message else { continue; }; let Some(content) = &message.content else { continue; }; let Some(text) = &content.text else { continue; }; - + yield Ok(RawStreamingChoice::Message(text.clone())); }, StreamingEvent::MessageEnd {delta: Some(delta)} => { @@ -145,7 +138,7 @@ impl StreamingCompletionModel for CompletionModel { }, StreamingEvent::ToolCallStart { delta: Some(delta)} => { // Skip the delta if there's any missing information, - // though this *should* all be present + // though this *should* all be present let Some(message) = &delta.message else { continue; }; let Some(tool_calls) = &message.tool_calls else { continue; }; let Some(id) = tool_calls.id.clone() else { continue; }; @@ -173,15 +166,15 @@ impl StreamingCompletionModel for CompletionModel { }, StreamingEvent::ToolCallEnd { .. } => { let Some(tc) = current_tool_call.clone() else { continue; }; - + let Ok(args) = serde_json::from_str(&tc.2) else { continue; }; - + yield Ok(RawStreamingChoice::ToolCall( tc.0, tc.1, args )); - + current_tool_call = None; }, _ => {} diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index c4bcdef..ac68953 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -17,7 +17,6 @@ use crate::message::AssistantContent; use crate::OneOrMany; use futures::{Stream, StreamExt}; use std::boxed::Box; -use std::fmt::{Display, Formatter}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll};