feat: galadriel api integration (redux) (#265)

* feat: galadriel api integration

* refactor: add cfg-attr for worker feature
This commit is contained in:
Joshua Mo 2025-02-03 23:20:03 +00:00 committed by GitHub
parent f869ce0a38
commit 3b6692b737
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 423 additions and 0 deletions

View File

@ -0,0 +1,23 @@
use rig::{completion::Prompt, providers};
use std::env;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Read the required API key (and the optional fine-tune key) from the
    // environment, then build the Galadriel client.
    let api_key = env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
    let fine_tune_key = env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
    let client = providers::galadriel::Client::new(&api_key, fine_tune_key.as_deref());

    // Build an agent backed by gpt-4o with a single system preamble.
    let comedian_agent = client
        .agent("gpt-4o")
        .preamble("You are a comedian here to entertain the user using humour and jokes.")
        .build();

    // Send one prompt and print the model's reply.
    let response = comedian_agent.prompt("Entertain me!").await?;
    println!("{}", response);
    Ok(())
}

View File

@ -0,0 +1,399 @@
//! Galadriel API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::galadriel;
//!
//! let client = galadriel::Client::new("YOUR_API_KEY", None);
//! // to use a fine-tuned model, pass the fine-tune key as `Some(...)`
//! // let client = galadriel::Client::new("YOUR_API_KEY", Some("FINE_TUNE_API_KEY"));
//!
//! let gpt4o = client.completion_model(galadriel::GPT_4O);
//! ```
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
// ================================================================
// Main Galadriel Client
// ================================================================
const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified";
/// Galadriel API client. Cheap to clone: the inner reqwest client is shared.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to (defaults to `GALADRIEL_API_BASE_URL`).
    base_url: String,
    /// HTTP client pre-configured with the authorization header(s).
    http_client: reqwest::Client,
}
impl Client {
    /// Create a new Galadriel client with the given API key and optional fine-tune API key.
    pub fn new(api_key: &str, fine_tune_api_key: Option<&str>) -> Self {
        Self::from_url_with_optional_key(api_key, GALADRIEL_API_BASE_URL, fine_tune_api_key)
    }

    /// Create a new Galadriel client with the given API key, base API URL, and optional fine-tune API key.
    pub fn from_url(api_key: &str, base_url: &str, fine_tune_api_key: Option<&str>) -> Self {
        Self::from_url_with_optional_key(api_key, base_url, fine_tune_api_key)
    }

    /// Create a new Galadriel client with the given API key, base API URL, and optional
    /// fine-tune API key. The keys are installed as default headers (`Authorization`
    /// and `Fine-Tune-Authorization` bearer tokens) on every request.
    ///
    /// # Panics
    /// Panics if either key contains characters that are invalid in an HTTP header
    /// value, or if the underlying reqwest client fails to build.
    pub fn from_url_with_optional_key(
        api_key: &str,
        base_url: &str,
        fine_tune_api_key: Option<&str>,
    ) -> Self {
        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers({
                    let mut headers = reqwest::header::HeaderMap::new();
                    headers.insert(
                        "Authorization",
                        format!("Bearer {}", api_key)
                            .parse()
                            .expect("Bearer token should parse"),
                    );
                    if let Some(key) = fine_tune_api_key {
                        headers.insert(
                            "Fine-Tune-Authorization",
                            format!("Bearer {}", key)
                                .parse()
                                .expect("Bearer token should parse"),
                        );
                    }
                    headers
                })
                .build()
                .expect("Galadriel reqwest client should build"),
        }
    }

    /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable,
    /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable.
    /// Panics if the `GALADRIEL_API_KEY` environment variable is not set.
    pub fn from_env() -> Self {
        let api_key = std::env::var("GALADRIEL_API_KEY").expect("GALADRIEL_API_KEY not set");
        let fine_tune_api_key = std::env::var("GALADRIEL_FINE_TUNE_API_KEY").ok();
        Self::new(&api_key, fine_tune_api_key.as_deref())
    }

    /// Build a POST request for `path` relative to the configured base URL.
    fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Join the two halves with exactly one '/' between them. The previous
        // implementation collapsed duplicate slashes with `.replace("//", "/")`,
        // which also corrupted the scheme separator ("https://" -> "https:/").
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        self.http_client.post(url)
    }

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::galadriel::{Client, self};
    ///
    /// // Initialize the Galadriel client
    /// let galadriel = Client::new("your-galadriel-api-key", None);
    ///
    /// let gpt4 = galadriel.completion_model(galadriel::GPT_4);
    /// ```
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Create an agent builder with the given completion model.
    ///
    /// # Example
    /// ```
    /// use rig::providers::galadriel::{Client, self};
    ///
    /// // Initialize the Galadriel client
    /// let galadriel = Client::new("your-galadriel-api-key", None);
    ///
    /// let agent = galadriel.agent(galadriel::GPT_4)
    ///    .preamble("You are comedian AI with a mission to make people laugh.")
    ///    .temperature(0.0)
    ///    .build();
    /// ```
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Create an extractor builder with the given completion model.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
/// Error payload returned by the Galadriel API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
/// Either a successful payload of type `T` or an API error body.
/// `#[serde(untagged)]`: deserialization tries `Ok(T)` first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
/// Token-usage accounting attached to a completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Total tokens (prompt plus completion).
    pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}
// ================================================================
// Galadriel Completion API
// ================================================================
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// Raw chat-completion response returned by the Galadriel API
/// (mirrors the OpenAI-style chat completion response shape).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Provider-assigned response identifier.
    pub id: String,
    /// Object type tag from the API (e.g. "chat.completion").
    pub object: String,
    /// Unix timestamp of creation, as reported by the API.
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Backend fingerprint, when the API provides one.
    pub system_fingerprint: Option<String>,
    /// Candidate completions; the first choice is used downstream.
    pub choices: Vec<Choice>,
    /// Token accounting, when the API provides it.
    pub usage: Option<Usage>,
}
impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;
    /// Convert the raw API response into rig's generic completion response.
    ///
    /// Only the FIRST choice is inspected. Tool calls take precedence over plain
    /// message content: if the first choice carries `tool_calls`, the first call
    /// is returned as `ModelChoice::ToolCall`; otherwise its `content` (if any)
    /// becomes `ModelChoice::Message`.
    ///
    /// # Errors
    /// Returns `ResponseError` when the tool-call list is present but empty, or
    /// when the first choice has neither tool calls nor content; propagates a
    /// JSON error if the tool-call arguments fail to parse.
    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
        match value.choices.as_slice() {
            // First choice has tool calls (checked before content on purpose).
            [Choice {
                message:
                    Message {
                        tool_calls: Some(calls),
                        ..
                    },
                ..
            }, ..] => {
                let call = calls.first().ok_or(CompletionError::ResponseError(
                    "Tool selection is empty".into(),
                ))?;
                Ok(completion::CompletionResponse {
                    choice: completion::ModelChoice::ToolCall(
                        call.function.name.clone(),
                        // No call id is propagated from this provider.
                        "".to_owned(),
                        // Arguments arrive as a JSON string; parse them here.
                        serde_json::from_str(&call.function.arguments)?,
                    ),
                    raw_response: value,
                })
            }
            // First choice is a plain assistant message.
            [Choice {
                message:
                    Message {
                        content: Some(content),
                        ..
                    },
                ..
            }, ..] => Ok(completion::CompletionResponse {
                choice: completion::ModelChoice::Message(content.to_string()),
                raw_response: value,
            }),
            // Empty choices, or a choice with neither content nor tool calls.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a message or tool call".into(),
            )),
        }
    }
}
/// One candidate completion within a response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response's `choices` array.
    pub index: usize,
    /// The assistant message for this choice.
    pub message: Message,
    /// Log-probability payload, passed through opaquely when present.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. "stop", "tool_calls").
    pub finish_reason: String,
}
/// A chat message as returned by the API.
#[derive(Debug, Deserialize)]
pub struct Message {
    /// Message author role (e.g. "assistant").
    pub role: String,
    /// Text content; absent when the model responded with tool calls only.
    pub content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// A single tool invocation requested by the model.
#[derive(Debug, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call identifier.
    pub id: String,
    /// Call kind; expected to be "function" for this API shape.
    pub r#type: String,
    /// The function name and serialized arguments to invoke.
    pub function: Function,
}
/// Provider-side tool definition: rig's core definition wrapped in the
/// `{"type": "function", "function": {...}}` envelope the API expects.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Envelope tag; set to "function" by the `From` impl below.
    pub r#type: String,
    /// The underlying tool definition (name, description, parameters).
    pub function: completion::ToolDefinition,
}
impl From<completion::ToolDefinition> for ToolDefinition {
fn from(tool: completion::ToolDefinition) -> Self {
Self {
r#type: "function".into(),
function: tool,
}
}
}
/// Function call payload inside a tool call.
#[derive(Debug, Deserialize)]
pub struct Function {
    /// Name of the function the model wants invoked.
    pub name: String,
    /// Arguments as a JSON-encoded string (parsed later with serde_json).
    pub arguments: String,
}
/// A Galadriel completion model: a client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to issue completion requests.
    client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}
impl CompletionModel {
    /// Build a completion-model handle for `model`, backed by `client`.
    pub fn new(client: Client, model: &str) -> Self {
        let model = model.to_owned();
        Self { client, model }
    }
}
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    /// Send a chat-completion request to the Galadriel API.
    ///
    /// Builds an OpenAI-style message list (system preamble, prior chat
    /// history, then the context-augmented user prompt), attaches tool
    /// definitions when present, merges any caller-supplied extra params,
    /// and converts the JSON response via `TryFrom` above.
    ///
    /// # Errors
    /// Returns `ProviderError` on HTTP/API failures and `ResponseError`
    /// (via `try_from`) when the response has no usable choice.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        mut completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history = if let Some(preamble) = &completion_request.preamble {
            vec![completion::Message {
                role: "system".into(),
                content: preamble.clone(),
            }]
        } else {
            vec![]
        };
        // Extend existing chat history
        full_history.append(&mut completion_request.chat_history);
        // Merge any context documents into the prompt text
        let prompt_with_context = completion_request.prompt_with_context();
        // Append the context-augmented user prompt as the final message
        full_history.push(completion::Message {
            role: "user".into(),
            content: prompt_with_context,
        });
        // Only include tool fields when tools were actually supplied.
        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };
        // Caller-supplied additional params override/extend the base request body.
        let response = self
            .client
            .post("/chat/completions")
            .json(
                &if let Some(params) = completion_request.additional_params {
                    json_utils::merge(request, params)
                } else {
                    request
                },
            )
            .send()
            .await?;
        if response.status().is_success() {
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    // Log token usage for observability before converting.
                    tracing::info!(target: "rig",
                        "Galadriel completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as a provider error.
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }
}

View File

@ -46,6 +46,7 @@
pub mod anthropic;
pub mod cohere;
pub mod deepseek;
pub mod galadriel;
pub mod gemini;
pub mod hyperbolic;
pub mod openai;