rig/rig-core/examples/gemini_streaming_with_tools.rs

112 lines
3.3 KiB
Rust

use anyhow::Result;
use rig::streaming::stream_to_stdout;
use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
},
"required": ["x", "y"]
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
},
"required": ["x", "y"]
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let calculator_agent = providers::gemini::Client::from_env()
.agent(providers::gemini::completion::GEMINI_1_5_FLASH)
.preamble(
"You are a calculator here to help the user perform arithmetic
operations. Use the tools provided to answer the user's question.
make your answer long, so we can test the streaming functionality,
like 20 words",
)
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
println!("Calculate 2 - 5");
let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await?;
stream_to_stdout(calculator_agent, &mut stream).await?;
Ok(())
}