mirror of https://github.com/0xplaygrounds/rig
feat: galadriel api integration (redux) (#265)
* feat: galadriel api integration * refactor: add cfg-attr for worker feature
This commit is contained in:
parent
f869ce0a38
commit
3b6692b737
|
@ -0,0 +1,23 @@
|
|||
use rig::{completion::Prompt, providers};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
// Create Galadriel client
|
||||
let client = providers::galadriel::Client::new(
|
||||
&env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set"),
|
||||
env::var("GALADRIEL_FINE_TUNE_API_KEY").ok().as_deref(),
|
||||
);
|
||||
|
||||
// Create agent with a single context prompt
|
||||
let comedian_agent = client
|
||||
.agent("gpt-4o")
|
||||
.preamble("You are a comedian here to entertain the user using humour and jokes.")
|
||||
.build();
|
||||
|
||||
// Prompt the agent and print the response
|
||||
let response = comedian_agent.prompt("Entertain me!").await?;
|
||||
println!("{}", response);
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,399 @@
|
|||
//! Galadriel API client and Rig integration
|
||||
//!
|
||||
//! # Example
|
||||
//! ```
|
||||
//! use rig::providers::galadriel;
|
||||
//!
|
||||
//! let client = galadriel::Client::new("YOUR_API_KEY", None);
|
||||
//! // to use a fine-tuned model
|
||||
//! // let client = galadriel::Client::new("YOUR_API_KEY", "FINE_TUNE_API_KEY");
|
||||
//!
|
||||
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
|
||||
//! ```
|
||||
use crate::{
|
||||
agent::AgentBuilder,
|
||||
completion::{self, CompletionError, CompletionRequest},
|
||||
extractor::ExtractorBuilder,
|
||||
json_utils,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
// ================================================================
|
||||
// Main Galadriel Client
|
||||
// ================================================================
|
||||
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client {
|
||||
base_url: String,
|
||||
http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Create a new Galadriel client with the given API key and optional fine-tune API key.
|
||||
pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
|
||||
Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
|
||||
}
|
||||
|
||||
/// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
|
||||
pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
|
||||
Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
|
||||
}
|
||||
|
||||
pub fn from_url_with_optional_key(
|
||||
api_key: &str,
|
||||
base_url: &str,
|
||||
fine_tune_api_key: Option<&str>,
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url: base_url.to_string(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.default_headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", api_key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
if let Some(key) = fine_tune_api_key {
|
||||
headers.insert(
|
||||
"Fine-Tune-Authorization",
|
||||
format!("Bearer {}", key)
|
||||
.parse()
|
||||
.expect("Bearer token should parse"),
|
||||
);
|
||||
}
|
||||
headers
|
||||
})
|
||||
.build()
|
||||
.expect("Galadriel reqwest client should build"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
|
||||
/// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
|
||||
/// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
|
||||
pub fn from_env() -> Self {
|
||||
let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
|
||||
let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
|
||||
Self::new(&api_key, fine_tune_api_key.as_deref())
|
||||
}
|
||||
fn post(&self, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
|
||||
self.http_client.post(url)
|
||||
}
|
||||
|
||||
/// Create a completion model with the given name.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::galadriel::{Client, self};
|
||||
///
|
||||
/// // Initialize the Galadriel client
|
||||
/// let galadriel = Client::new("your-galadriel-api-key", None);
|
||||
///
|
||||
/// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
|
||||
/// ```
|
||||
pub fn completion_model(&self, model: &str) -> CompletionModel {
|
||||
CompletionModel::new(self.clone(), model)
|
||||
}
|
||||
|
||||
/// Create an agent builder with the given completion model.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rig::providers::galadriel::{Client, self};
|
||||
///
|
||||
/// // Initialize the Galadriel client
|
||||
/// let galadriel = Client::new("your-galadriel-api-key", None);
|
||||
///
|
||||
/// let agent = galadriel.agent(galadriel::GPT_4)
|
||||
/// .preamble("You are comedian AI with a mission to make people laugh.")
|
||||
/// .temperature(0.0)
|
||||
/// .build();
|
||||
/// ```
|
||||
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
|
||||
AgentBuilder::new(self.completion_model(model))
|
||||
}
|
||||
|
||||
/// Create an extractor builder with the given completion model.
|
||||
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
|
||||
&self,
|
||||
model: &str,
|
||||
) -> ExtractorBuilder<T, CompletionModel> {
|
||||
ExtractorBuilder::new(self.completion_model(model))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiErrorResponse {
|
||||
message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum ApiResponse<T> {
|
||||
Ok(T),
|
||||
Err(ApiErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Usage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Prompt tokens: {} Total tokens: {}",
|
||||
self.prompt_tokens, self.total_tokens
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Galadriel Completion API
|
||||
// ================================================================
|
||||
/// `o1-preview` completion model
|
||||
pub const O1_PREVIEW: &str = "o1-preview";
|
||||
/// `o1-preview-2024-09-12` completion model
|
||||
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
|
||||
/// `o1-mini completion model
|
||||
pub const O1_MINI: &str = "o1-mini";
|
||||
/// `o1-mini-2024-09-12` completion model
|
||||
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
|
||||
/// `gpt-4o` completion model
|
||||
pub const GPT_4O: &str = "gpt-4o";
|
||||
/// `gpt-4o-2024-05-13` completion model
|
||||
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
|
||||
/// `gpt-4-turbo` completion model
|
||||
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
|
||||
/// `gpt-4-turbo-2024-04-09` completion model
|
||||
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
|
||||
/// `gpt-4-turbo-preview` completion model
|
||||
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
|
||||
/// `gpt-4-0125-preview` completion model
|
||||
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
|
||||
/// `gpt-4-1106-preview` completion model
|
||||
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
|
||||
/// `gpt-4-vision-preview` completion model
|
||||
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
|
||||
/// `gpt-4-1106-vision-preview` completion model
|
||||
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
|
||||
/// `gpt-4` completion model
|
||||
pub const GPT_4: &str = "gpt-4";
|
||||
/// `gpt-4-0613` completion model
|
||||
pub const GPT_4_0613: &str = "gpt-4-0613";
|
||||
/// `gpt-4-32k` completion model
|
||||
pub const GPT_4_32K: &str = "gpt-4-32k";
|
||||
/// `gpt-4-32k-0613` completion model
|
||||
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
|
||||
/// `gpt-3.5-turbo` completion model
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
/// `gpt-3.5-turbo-0125` completion model
|
||||
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
|
||||
/// `gpt-3.5-turbo-1106` completion model
|
||||
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
|
||||
/// `gpt-3.5-turbo-instruct` completion model
|
||||
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl From<ApiErrorResponse> for CompletionError {
|
||||
fn from(err: ApiErrorResponse) -> Self {
|
||||
CompletionError::ProviderError(err.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
|
||||
match value.choices.as_slice() {
|
||||
[Choice {
|
||||
message:
|
||||
Message {
|
||||
tool_calls: Some(calls),
|
||||
..
|
||||
},
|
||||
..
|
||||
}, ..] => {
|
||||
let call = calls.first().ok_or(CompletionError::ResponseError(
|
||||
"Tool selection is empty".into(),
|
||||
))?;
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice: completion::ModelChoice::ToolCall(
|
||||
call.function.name.clone(),
|
||||
"".to_owned(),
|
||||
serde_json::from_str(&call.function.arguments)?,
|
||||
),
|
||||
raw_response: value,
|
||||
})
|
||||
}
|
||||
[Choice {
|
||||
message:
|
||||
Message {
|
||||
content: Some(content),
|
||||
..
|
||||
},
|
||||
..
|
||||
}, ..] => Ok(completion::CompletionResponse {
|
||||
choice: completion::ModelChoice::Message(content.to_string()),
|
||||
raw_response: value,
|
||||
}),
|
||||
_ => Err(CompletionError::ResponseError(
|
||||
"Response did not contain a message or tool call".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub logprobs: Option<serde_json::Value>,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub r#type: String,
|
||||
pub function: completion::ToolDefinition,
|
||||
}
|
||||
|
||||
impl From<completion::ToolDefinition> for ToolDefinition {
|
||||
fn from(tool: completion::ToolDefinition) -> Self {
|
||||
Self {
|
||||
r#type: "function".into(),
|
||||
function: tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct Function {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompletionModel {
|
||||
client: Client,
|
||||
/// Name of the model (e.g.: gpt-3.5-turbo-1106)
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl CompletionModel {
|
||||
pub fn new(client: Client, model: &str) -> Self {
|
||||
Self {
|
||||
client,
|
||||
model: model.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl completion::CompletionModel for CompletionModel {
|
||||
type Response = CompletionResponse;
|
||||
|
||||
#[cfg_attr(feature = "worker", worker::send)]
|
||||
async fn completion(
|
||||
&self,
|
||||
mut completion_request: CompletionRequest,
|
||||
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
|
||||
// Add preamble to chat history (if available)
|
||||
let mut full_history = if let Some(preamble) = &completion_request.preamble {
|
||||
vec![completion::Message {
|
||||
role: "system".into(),
|
||||
content: preamble.clone(),
|
||||
}]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Extend existing chat history
|
||||
full_history.append(&mut completion_request.chat_history);
|
||||
|
||||
// Add context documents to chat history
|
||||
let prompt_with_context = completion_request.prompt_with_context();
|
||||
|
||||
// Add context documents to chat history
|
||||
full_history.push(completion::Message {
|
||||
role: "user".into(),
|
||||
content: prompt_with_context,
|
||||
});
|
||||
|
||||
let request = if completion_request.tools.is_empty() {
|
||||
json!({
|
||||
"model": self.model,
|
||||
"messages": full_history,
|
||||
"temperature": completion_request.temperature,
|
||||
})
|
||||
} else {
|
||||
json!({
|
||||
"model": self.model,
|
||||
"messages": full_history,
|
||||
"temperature": completion_request.temperature,
|
||||
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
|
||||
"tool_choice": "auto",
|
||||
})
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("/chat/completions")
|
||||
.json(
|
||||
&if let Some(params) = completion_request.additional_params {
|
||||
json_utils::merge(request, params)
|
||||
} else {
|
||||
request
|
||||
},
|
||||
)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
match response.json::<ApiResponse<CompletionResponse>>().await? {
|
||||
ApiResponse::Ok(response) => {
|
||||
tracing::info!(target: "rig",
|
||||
"Galadriel completion token usage: {:?}",
|
||||
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
|
||||
);
|
||||
response.try_into()
|
||||
}
|
||||
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
|
||||
}
|
||||
} else {
|
||||
Err(CompletionError::ProviderError(response.text().await?))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -46,6 +46,7 @@
|
|||
pub mod anthropic;
|
||||
pub mod cohere;
|
||||
pub mod deepseek;
|
||||
pub mod galadriel;
|
||||
pub mod gemini;
|
||||
pub mod hyperbolic;
|
||||
pub mod openai;
|
||||
|
|
Loading…
Reference in New Issue