diff --git a/rig-core/examples/openrouter_streaming_with_tools.rs b/rig-core/examples/openrouter_streaming_with_tools.rs new file mode 100644 index 0000000..96a2256 --- /dev/null +++ b/rig-core/examples/openrouter_streaming_with_tools.rs @@ -0,0 +1,118 @@ +use anyhow::Result; +use rig::streaming::stream_to_stdout; +use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + }, + "required": ["x", "y"] + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + }, + "required": ["x", "y"] + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt().init(); + // Create agent with a single context prompt and two tools + let calculator_agent = providers::openrouter::Client::from_env() + .agent(providers::openrouter::GEMINI_FLASH_2_0) + .preamble( + "You are a calculator here to help the user perform arithmetic + operations. Use the tools provided to answer the user's question. + make your answer long, so we can test the streaming functionality, + like 20 words", + ) + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .build(); + + println!("Calculate 2 - 5"); + let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?; + stream_to_stdout(calculator_agent, &mut stream).await?; + + if let Some(response) = stream.response { + println!("Usage: {:?}", response.usage) + }; + + println!("Message: {:?}", stream.choice); + + Ok(()) +} \ No newline at end of file