rig/rig-core/examples/huggingface_subproviders.rs

150 lines
4.3 KiB
Rust

use std::env;
use rig::{
agent::AgentBuilder,
completion::{Prompt, ToolDefinition},
providers::{self, huggingface::SubProvider},
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let models = [
("deepseek-ai/DeepSeek-V3", SubProvider::Together),
("meta-llama/Llama-3.1-8B-Instruct", SubProvider::HFInference),
("Meta-Llama-3.1-8B-Instruct", SubProvider::SambaNova),
("deepseek-v3", SubProvider::Fireworks),
("Qwen/Qwen2.5-32B-Instruct", SubProvider::Nebius),
];
for (model, sub_provider) in models {
tools(model, sub_provider).await?;
}
Ok(())
}
fn client(sub_provider: SubProvider) -> providers::huggingface::Client {
let api_key = &env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY not set");
providers::huggingface::ClientBuilder::new(&api_key)
.sub_provider(sub_provider)
.build()
}
/// Create a partial huggingface agent (deepseek R1)
fn partial_agent(
model: &str,
sub_provider: SubProvider,
) -> AgentBuilder<providers::huggingface::completion::CompletionModel> {
let client = client(sub_provider);
client.agent(model)
}
/// Create an huggingface agent (deepseek R1) with tools
/// Based upon the `tools` example
///
/// This example creates a calculator agent with two tools: add and subtract
async fn tools(model: &str, sub_provider: SubProvider) -> Result<(), anyhow::Error> {
// Create agent with a single context prompt and two tools
let calculator_agent = partial_agent(model, sub_provider.clone())
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();
// Prompt the agent and print the response
println!("Asking {} on {:?} to Calculate 2 - 5", model, sub_provider);
println!(
"Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);
Ok(())
}
#[derive(Deserialize)]
struct OperationArgs {
x: f32,
y: f32,
}
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = f32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = f32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to subtract from"
},
"y": {
"type": "number",
"description": "The number to subtract"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}