feat: Add image generation to all providers that support it (#357)

* feat: add transcription to all providers that support it

* feat: gemini streaming

* chore: fix clippy warnings

* feat: together streaming

* refactor: move transcription preamble to constant

* chore: fix clippy

* fix: enable blob by default

* fix: lock methods and imports behind feature flag

* chore: fmt

* feat: add huggingface transcription, remove feature flag

* feat: streaming for most openai type providers

* feat: add streaming to remaining providers

* chore: formatting

* fix: gemini streaming wasn't added to mod

* chore: unused imports in example

* chore: unused import

* chore: run fmt

* feat: openai & azure image generation

* chore: reorganize a bit

* refactor: simplify oai-compatible streaming

* feat: huggingface image generation

* chore: run fmt & clippy

* feat: hyperbolic image generation

* refactor: break up openai

* fix: lock changes behind feature flag

* chore: remove unused import

* refactor: change size to width/height, update examples

* fix: error with examples

* chore: remove unused imports

* fix: feature in lib not gated behind flag

* fix: used wrong flag

* fix: import not behind flag

* fix: more flag issues

---------

Co-authored-by: yavens <179155341+yavens@users.github.noreply.github.com>
yavens 2025-03-19 16:32:17 -04:00 committed by GitHub
parent ab8cb89948
commit 70855257c4
28 changed files with 2333 additions and 1619 deletions


@ -48,6 +48,7 @@ serde_path_to_error = "0.1.16"
[features]
default = ["reqwest/default"]
all = ["derive", "pdf", "rayon"]
image = []
derive = ["dep:rig-derive"]
pdf = ["dep:lopdf"]
epub = ["dep:epub", "dep:quick-xml"]


@ -1,4 +1,4 @@
use rig::providers::openai::Client;
use rig::providers::openai::client::Client;
use schemars::JsonSchema;
use std::env;


@ -1,6 +1,7 @@
use std::env;
use rig::{completion::Prompt, providers::openai::Client};
use rig::completion::Prompt;
use rig::providers::openai::client::Client;
use schemars::JsonSchema;
#[derive(serde::Deserialize, JsonSchema, serde::Serialize, Debug)]


@ -1,6 +1,6 @@
use std::env;
use rig::providers::openai::Client;
use rig::providers::openai::client::Client;
use schemars::JsonSchema;


@ -1,10 +1,10 @@
use std::env;
use rig::pipeline::agent_ops::extract;
use rig::providers::openai::client::Client;
use rig::{
parallel,
pipeline::{self, passthrough, Op},
providers::openai::Client,
};
use schemars::JsonSchema;


@ -1,9 +1,7 @@
use std::env;
use rig::{
pipeline::{self, Op},
providers::openai::Client,
};
use rig::pipeline::{self, Op};
use rig::providers::openai::client::Client;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {


@ -1,9 +1,7 @@
use std::env;
use rig::{
pipeline::{self, Op, TryOp},
providers::openai::Client,
};
use rig::pipeline::{self, Op, TryOp};
use rig::providers::openai::client::Client;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {


@ -0,0 +1,38 @@
use rig::image_generation::ImageGenerationModel;
use rig::providers::huggingface;
use std::env::args;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const DEFAULT_PATH: &str = "./output.png";
#[tokio::main]
async fn main() {
let arguments: Vec<String> = args().collect();
let path = if arguments.len() > 1 {
arguments[1].clone()
} else {
DEFAULT_PATH.to_string()
};
let path = Path::new(&path);
let mut file = File::create_new(path).expect("Failed to create file");
let huggingface = huggingface::Client::from_env();
let stable_diffusion = huggingface.image_generation_model(huggingface::STABLE_DIFFUSION_3);
let response = stable_diffusion
.image_generation_request()
.prompt("A castle sitting upon a large mountain, overlooking the water.")
.width(1024)
.height(1024)
.send()
.await
.expect("Failed to generate image");
file.write_all(&response.image).expect("Failed to write image");
}


@ -0,0 +1,37 @@
use rig::image_generation::ImageGenerationModel;
use rig::providers::hyperbolic;
use std::env::args;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const DEFAULT_PATH: &str = "./output.png";
#[tokio::main]
async fn main() {
let arguments: Vec<String> = args().collect();
let path = if arguments.len() > 1 {
arguments[1].clone()
} else {
DEFAULT_PATH.to_string()
};
let path = Path::new(&path);
let mut file = File::create_new(path).expect("Failed to create file");
let hyperbolic = hyperbolic::Client::from_env();
let stable_diffusion = hyperbolic.image_generation_model(hyperbolic::SDXL_TURBO);
let response = stable_diffusion
.image_generation_request()
.prompt("A castle sitting upon a large mountain, overlooking the water.")
.width(1024)
.height(1024)
.send()
.await
.expect("Failed to generate image");
file.write_all(&response.image).expect("Failed to write image");
}


@ -0,0 +1,37 @@
use rig::image_generation::ImageGenerationModel;
use rig::providers::openai;
use std::env::args;
use std::fs::File;
use std::io::Write;
use std::path::Path;
const DEFAULT_PATH: &str = "./output.png";
#[tokio::main]
async fn main() {
let arguments: Vec<String> = args().collect();
let path = if arguments.len() > 1 {
arguments[1].clone()
} else {
DEFAULT_PATH.to_string()
};
let path = Path::new(&path);
let mut file = File::create_new(path).expect("Failed to create file");
let openai = openai::Client::from_env();
let dalle = openai.image_generation_model(openai::DALL_E_2);
let response = dalle
.image_generation_request()
.prompt("A castle sitting upon a large mountain, overlooking the water.")
.width(1024)
.height(1024)
.send()
.await
.expect("Failed to generate image");
file.write_all(&response.image).expect("Failed to write image");
}


@ -1,11 +1,9 @@
use std::{env, vec};
use rig::providers::openai::client::Client;
use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
Embed,
completion::Prompt, embeddings::EmbeddingsBuilder, providers::openai::TEXT_EMBEDDING_ADA_002,
vector_store::in_memory_store::InMemoryVectorStore, Embed,
};
use serde::Serialize;


@ -1,8 +1,9 @@
use std::env;
use rig::providers::openai::client::Client;
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
providers::openai::TEXT_EMBEDDING_ADA_002,
vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
Embed,
};


@ -0,0 +1,128 @@
use serde_json::Value;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ImageGenerationError {
/// Http error (e.g.: connection error, timeout, etc.)
#[error("HttpError: {0}")]
HttpError(#[from] reqwest::Error),
/// Json error (e.g.: serialization, deserialization)
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// Error building the image generation request
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// Error parsing the image generation response
#[error("ResponseError: {0}")]
ResponseError(String),
/// Error returned by the image generation model provider
#[error("ProviderError: {0}")]
ProviderError(String),
}
pub trait ImageGeneration<M: ImageGenerationModel> {
/// Generates an image generation request builder for the given prompt and size.
/// This function is meant to be called by the user to further customize the
/// request at image generation time before sending it.
///
/// ❗IMPORTANT: The type that implements this trait might have already
/// populated fields in the builder (the exact fields depend on the type).
/// For fields that have already been set by the model, calling the corresponding
/// method on the builder will overwrite the value set by the model.
fn image_generation(
&self,
prompt: &str,
size: &(u32, u32),
) -> impl std::future::Future<
Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
> + Send;
}
pub struct ImageGenerationResponse<T> {
pub image: Vec<u8>,
pub response: T,
}
pub trait ImageGenerationModel: Clone + Send + Sync {
type Response: Send + Sync;
fn image_generation(
&self,
request: ImageGenerationRequest,
) -> impl std::future::Future<
Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
> + Send;
fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
ImageGenerationRequestBuilder::new(self.clone())
}
}
pub struct ImageGenerationRequest {
pub prompt: String,
pub width: u32,
pub height: u32,
pub additional_params: Option<Value>,
}
pub struct ImageGenerationRequestBuilder<M: ImageGenerationModel> {
model: M,
prompt: String,
width: u32,
height: u32,
additional_params: Option<Value>,
}
impl<M: ImageGenerationModel> ImageGenerationRequestBuilder<M> {
pub fn new(model: M) -> Self {
Self {
model,
prompt: "".to_string(),
height: 256,
width: 256,
additional_params: None,
}
}
/// Sets the prompt for the image generation request
pub fn prompt(mut self, prompt: &str) -> Self {
self.prompt = prompt.to_string();
self
}
/// The width of the generated image
pub fn width(mut self, width: u32) -> Self {
self.width = width;
self
}
/// The height of the generated image
pub fn height(mut self, height: u32) -> Self {
self.height = height;
self
}
/// Adds additional parameters to the image generation request.
pub fn additional_params(mut self, params: Value) -> Self {
self.additional_params = Some(params);
self
}
pub fn build(self) -> ImageGenerationRequest {
ImageGenerationRequest {
prompt: self.prompt,
width: self.width,
height: self.height,
additional_params: self.additional_params,
}
}
pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
let model = self.model.clone();
model.image_generation(self.build()).await
}
}
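To make the new abstraction concrete, here is a minimal sketch (not part of the commit's diff) of driving the request builder generically; the helper name and the 1024x1024 size are illustrative, and the crate must be built with the new `image` feature:

use rig::image_generation::{ImageGenerationError, ImageGenerationModel};

// Hypothetical helper: works with any of the provider models added in this
// commit (OpenAI, Azure OpenAI, Hyperbolic, Huggingface), since they all
// implement `ImageGenerationModel`.
async fn generate_image<M: ImageGenerationModel>(
    model: &M,
    prompt: &str,
) -> Result<Vec<u8>, ImageGenerationError> {
    let response = model
        .image_generation_request() // builder defaults to 256x256 (see above)
        .prompt(prompt)
        .width(1024) // overrides the builder default
        .height(1024)
        .send()
        .await?;
    // `response.image` holds the raw image bytes; `response.response` keeps the
    // provider-specific payload.
    Ok(response.image)
}

Fields pre-set by an `ImageGeneration` implementor can be overwritten the same way before `.send()`, per the trait docs above.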


@ -85,6 +85,8 @@ pub mod cli_chatbot;
pub mod completion;
pub mod embeddings;
pub mod extractor;
#[cfg(feature = "image")]
pub mod image_generation;
pub(crate) mod json_utils;
pub mod loaders;
pub mod one_or_many;


@ -10,6 +10,12 @@
//! ```
use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
#[cfg(feature = "image")]
use super::openai::ImageGenerationResponse;
#[cfg(feature = "image")]
use crate::image_generation::{self, ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@ -157,6 +163,16 @@ impl Client {
self.http_client.post(url)
}
#[cfg(feature = "image")]
fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
let url = format!(
"{}/openai/deployments/{}/images/generations?api-version={}",
self.azure_endpoint, deployment_id, self.api_version
)
.replace("//", "/");
self.http_client.post(url)
}
/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
@ -642,6 +658,55 @@ impl transcription::TranscriptionModel for TranscriptionModel {
}
}
// ================================================================
// Azure OpenAI Image Generation API
// ================================================================
#[cfg(feature = "image")]
#[derive(Clone)]
pub struct ImageGenerationModel {
client: Client,
pub model: String,
}
#[cfg(feature = "image")]
impl image_generation::ImageGenerationModel for ImageGenerationModel {
type Response = ImageGenerationResponse;
async fn image_generation(
&self,
generation_request: ImageGenerationRequest,
) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
{
let request = json!({
"model": self.model,
"prompt": generation_request.prompt,
"size": format!("{}x{}", generation_request.width, generation_request.height),
"response_format": "b64_json"
});
let response = self
.client
.post_image_generation(&self.model)
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(ImageGenerationError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
let t = response.text().await?;
match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
}
}
}
#[cfg(test)]
mod azure_tests {
use super::*;


@ -2,7 +2,12 @@ use std::fmt::Display;
use super::completion::CompletionModel;
use crate::agent::AgentBuilder;
#[cfg(feature = "image")]
use crate::image_generation::ImageGenerationError;
#[cfg(feature = "image")]
use crate::providers::huggingface::image_generation::ImageGenerationModel;
use crate::providers::huggingface::transcription::TranscriptionModel;
use crate::transcription::TranscriptionError;
// ================================================================
// Main Huggingface Client
@ -36,10 +41,27 @@ impl SubProvider {
/// Get the transcription endpoint for the SubProvider
/// Required because Huggingface Inference requires the model
/// in the url and in the request body.
pub fn transcription_endpoint(&self, model: &str) -> String {
pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
match self {
SubProvider::HFInference => format!("hf-inference/models/{}", model),
_ => panic!("transcription endpoint is not supported yet for {}", self),
SubProvider::HFInference => Ok(format!("/{}", model)),
_ => Err(TranscriptionError::ProviderError(format!(
"transcription endpoint is not supported yet for {}",
self
))),
}
}
/// Get the image generation endpoint for the SubProvider
/// Required because Huggingface Inference requires the model
/// in the url and in the request body.
#[cfg(feature = "image")]
pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
match self {
SubProvider::HFInference => Ok(format!("/{}", model)),
_ => Err(ImageGenerationError::ProviderError(format!(
"image generation endpoint is not supported yet for {}",
self
))),
}
}
@ -124,7 +146,9 @@ pub struct Client {
impl Client {
/// Create a new Huggingface client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, HUGGINGFACE_API_BASE_URL, SubProvider::HFInference)
let base_url =
format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
Self::from_url(api_key, &base_url, SubProvider::HFInference)
}
/// Create a new Client with the given API key and base API URL.
@ -193,10 +217,27 @@ impl Client {
///
/// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
/// ```
///
pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
TranscriptionModel::new(self.clone(), model)
}
/// Create a new image generation model with the given name
///
/// # Example
/// ```
/// use rig::providers::huggingface::{Client, self};
///
/// // Initialize the Huggingface client
/// let client = Client::new("your-huggingface-api-key");
///
/// let stable_diffusion = client.image_generation_model(huggingface::STABLE_DIFFUSION_3);
/// ```
#[cfg(feature = "image")]
pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
ImageGenerationModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example


@ -0,0 +1,78 @@
use super::Client;
use crate::image_generation;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use serde_json::json;
pub const FLUX_1: &str = "black-forest-labs/FLUX.1-dev";
pub const KOLORS: &str = "Kwai-Kolors/Kolors";
pub const STABLE_DIFFUSION_3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
#[derive(Debug)]
pub struct ImageGenerationResponse {
data: Vec<u8>,
}
impl TryFrom<ImageGenerationResponse>
for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
type Error = ImageGenerationError;
fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
Ok(image_generation::ImageGenerationResponse {
image: value.data.clone(),
response: value,
})
}
}
#[derive(Clone)]
pub struct ImageGenerationModel {
client: Client,
pub model: String,
}
impl ImageGenerationModel {
pub fn new(client: Client, model: &str) -> Self {
ImageGenerationModel {
client,
model: model.to_string(),
}
}
}
impl image_generation::ImageGenerationModel for ImageGenerationModel {
type Response = ImageGenerationResponse;
async fn image_generation(
&self,
request: ImageGenerationRequest,
) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
{
let request = json!({
"inputs": request.prompt,
"parameters": {
"width": request.width,
"height": request.height
}
});
let route = self
.client
.sub_provider
.image_generation_endpoint(&self.model)?;
let response = self.client.post(&route).json(&request).send().await?;
if !response.status().is_success() {
return Err(ImageGenerationError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
let data = response.bytes().await?.to_vec();
ImageGenerationResponse { data }.try_into()
}
}


@ -12,6 +12,9 @@
pub mod client;
pub mod completion;
#[cfg(feature = "image")]
pub mod image_generation;
pub mod streaming;
pub mod transcription;
@ -20,4 +23,7 @@ pub use completion::{
GEMMA_2, META_LLAMA_3_1, PHI_4, QWEN2_5, QWEN2_5_CODER, QWEN2_VL, QWEN_QVQ_PREVIEW,
SMALLTHINKER_PREVIEW,
};
#[cfg(feature = "image")]
pub use image_generation::{FLUX_1, KOLORS, STABLE_DIFFUSION_3};
pub use transcription::{WHISPER_LARGE_V3, WHISPER_LARGE_V3_TURBO, WHISPER_SMALL};


@ -59,7 +59,10 @@ impl transcription::TranscriptionModel for TranscriptionModel {
"inputs": data
});
let route = self.client.sub_provider.transcription_endpoint(&self.model);
let route = self
.client
.sub_provider
.transcription_endpoint(&self.model)?;
let response = self.client.post(&route).json(&request).send().await?;
if response.status().is_success() {


@ -10,6 +10,9 @@
//! ```
use super::openai::{send_compatible_streaming_request, AssistantContent};
#[cfg(feature = "image")]
use crate::image_generation::{self, ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@ -20,6 +23,11 @@ use crate::{
providers::openai::Message,
OneOrMany,
};
#[cfg(feature = "image")]
use base64::prelude::BASE64_STANDARD;
#[cfg(feature = "image")]
use base64::Engine;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@ -88,6 +96,22 @@ impl Client {
CompletionModel::new(self.clone(), model)
}
/// Create an image generation model with the given name.
///
/// # Example
/// ```
/// use rig::providers::hyperbolic::{Client, self};
///
/// // Initialize the Hyperbolic client
/// let hyperbolic = Client::new("your-hyperbolic-api-key");
///
/// let sdxl_turbo = hyperbolic.image_generation_model(hyperbolic::SDXL_TURBO);
/// ```
#[cfg(feature = "image")]
pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
ImageGenerationModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
@ -369,3 +393,106 @@ impl StreamingCompletionModel for CompletionModel {
send_compatible_streaming_request(builder).await
}
}
// =======================================
// Hyperbolic Image Generation API
// =======================================
pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
pub const SD2: &str = "SD2";
pub const SD1_5: &str = "SD1.5";
pub const SSD: &str = "SSD";
pub const SDXL_TURBO: &str = "SDXL-turbo";
pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";
#[cfg(feature = "image")]
#[derive(Clone)]
pub struct ImageGenerationModel {
client: Client,
pub model: String,
}
#[cfg(feature = "image")]
impl ImageGenerationModel {
fn new(client: Client, model: &str) -> ImageGenerationModel {
Self {
client,
model: model.to_string(),
}
}
}
#[cfg(feature = "image")]
#[derive(Clone, Deserialize)]
pub struct Image {
image: String,
}
#[cfg(feature = "image")]
#[derive(Clone, Deserialize)]
pub struct ImageGenerationResponse {
images: Vec<Image>,
}
#[cfg(feature = "image")]
impl TryFrom<ImageGenerationResponse>
for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
type Error = ImageGenerationError;
fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
let data = BASE64_STANDARD
.decode(&value.images[0].image)
.expect("Could not decode image.");
Ok(Self {
image: data,
response: value,
})
}
}
#[cfg(feature = "image")]
impl image_generation::ImageGenerationModel for ImageGenerationModel {
type Response = ImageGenerationResponse;
async fn image_generation(
&self,
generation_request: ImageGenerationRequest,
) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
{
let mut request = json!({
"model_name": self.model,
"prompt": generation_request.prompt,
"height": generation_request.height,
"width": generation_request.width,
});
if let Some(params) = generation_request.additional_params {
merge_inplace(&mut request, params);
}
let response = self
.client
.post("/image/generation")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(ImageGenerationError::ProviderError(format!(
"{}: {}",
response.status().as_str(),
response.text().await?
)));
}
match response
.json::<ApiResponse<ImageGenerationResponse>>()
.await?
{
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
}
}
}

File diff suppressed because it is too large


@ -0,0 +1,515 @@
use super::completion::CompletionModel;
use super::embedding::{
EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
};
#[cfg(feature = "image")]
use super::image_generation::ImageGenerationModel;
use super::transcription::TranscriptionModel;
use crate::agent::AgentBuilder;
use crate::embeddings::EmbeddingsBuilder;
use crate::extractor::ExtractorBuilder;
use crate::Embed;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
// ================================================================
// Main OpenAI Client
// ================================================================
const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}
impl Client {
/// Create a new OpenAI client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, OPENAI_API_BASE_URL)
}
/// Create a new OpenAI client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("OpenAI reqwest client should build"),
}
}
/// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
Self::new(&api_key)
}
pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
let ndims = match model {
TEXT_EMBEDDING_3_LARGE => 3072,
TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
_ => 0,
};
EmbeddingModel::new(self.clone(), model, ndims)
}
/// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let embedding_model = openai.embedding_model("model-unknown-to-rig", 3072);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, ndims)
}
/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model))
}
/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let gpt4 = openai.completion_model(openai::GPT_4);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let agent = openai.agent(openai::GPT_4)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
/// Create a transcription model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let whisper = openai.transcription_model(openai::WHISPER_1);
/// ```
pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
TranscriptionModel::new(self.clone(), model)
}
/// Create an image generation model with the given name.
///
/// # Example
/// ```
/// use rig::providers::openai::{Client, self};
///
/// // Initialize the OpenAI client
/// let openai = Client::new("your-open-ai-api-key");
///
/// let dalle = openai.image_generation_model(openai::DALL_E_3);
/// ```
#[cfg(feature = "image")]
pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
ImageGenerationModel::new(self.clone(), model)
}
}
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub total_tokens: usize,
}
impl std::fmt::Display for Usage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Prompt tokens: {} Total tokens: {}",
self.prompt_tokens, self.total_tokens
)
}
}
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub(crate) message: String,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}
#[cfg(test)]
mod tests {
use crate::message::ImageDetail;
use crate::providers::openai::{
AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
};
use crate::{message, OneOrMany};
use serde_path_to_error::deserialize;
#[test]
fn test_deserialize_message() {
let assistant_message_json = r#"
{
"role": "assistant",
"content": "\n\nHello there, how may I assist you today?"
}
"#;
let assistant_message_json2 = r#"
{
"role": "assistant",
"content": [
{
"type": "text",
"text": "\n\nHello there, how may I assist you today?"
}
],
"tool_calls": null
}
"#;
let assistant_message_json3 = r#"
{
"role": "assistant",
"tool_calls": [
{
"id": "call_h89ipqYUjEpCPI6SxspMnoUU",
"type": "function",
"function": {
"name": "subtract",
"arguments": "{\"x\": 2, \"y\": 5}"
}
}
],
"content": null,
"refusal": null
}
"#;
let user_message_json = r#"
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
},
{
"type": "audio",
"input_audio": {
"data": "...",
"format": "mp3"
}
}
]
}
"#;
let assistant_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let assistant_message2: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let assistant_message3: Message = {
let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
&mut serde_json::Deserializer::from_str(assistant_message_json3);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
let user_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(user_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!(
"Deserialization error at {} ({}:{}): {}",
err.path(),
err.inner().line(),
err.inner().column(),
err
);
})
};
match assistant_message {
Message::Assistant { content, .. } => {
assert_eq!(
content[0],
AssistantContent::Text {
text: "\n\nHello there, how may I assist you today?".to_string()
}
);
}
_ => panic!("Expected assistant message"),
}
match assistant_message2 {
Message::Assistant {
content,
tool_calls,
..
} => {
assert_eq!(
content[0],
AssistantContent::Text {
text: "\n\nHello there, how may I assist you today?".to_string()
}
);
assert_eq!(tool_calls, vec![]);
}
_ => panic!("Expected assistant message"),
}
match assistant_message3 {
Message::Assistant {
content,
tool_calls,
refusal,
..
} => {
assert!(content.is_empty());
assert!(refusal.is_none());
assert_eq!(
tool_calls[0],
ToolCall {
id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
r#type: ToolType::Function,
function: Function {
name: "subtract".to_string(),
arguments: serde_json::json!({"x": 2, "y": 5}),
},
}
);
}
_ => panic!("Expected assistant message"),
}
match user_message {
Message::User { content, .. } => {
let (first, second) = {
let mut iter = content.into_iter();
(iter.next().unwrap(), iter.next().unwrap())
};
assert_eq!(
first,
UserContent::Text {
text: "What's in this image?".to_string()
}
);
assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
}
_ => panic!("Expected user message"),
}
}
#[test]
fn test_message_to_message_conversion() {
let user_message = message::Message::User {
content: OneOrMany::one(message::UserContent::text("Hello")),
};
let assistant_message = message::Message::Assistant {
content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
};
let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
let converted_assistant_message: Vec<Message> =
assistant_message.clone().try_into().unwrap();
match converted_user_message[0].clone() {
Message::User { content, .. } => {
assert_eq!(
content.first(),
UserContent::Text {
text: "Hello".to_string()
}
);
}
_ => panic!("Expected user message"),
}
match converted_assistant_message[0].clone() {
Message::Assistant { content, .. } => {
assert_eq!(
content[0].clone(),
AssistantContent::Text {
text: "Hi there!".to_string()
}
);
}
_ => panic!("Expected assistant message"),
}
let original_user_message: message::Message =
converted_user_message[0].clone().try_into().unwrap();
let original_assistant_message: message::Message =
converted_assistant_message[0].clone().try_into().unwrap();
assert_eq!(original_user_message, user_message);
assert_eq!(original_assistant_message, assistant_message);
}
#[test]
fn test_message_from_message_conversion() {
let user_message = Message::User {
content: OneOrMany::one(UserContent::Text {
text: "Hello".to_string(),
}),
name: None,
};
let assistant_message = Message::Assistant {
content: vec![AssistantContent::Text {
text: "Hi there!".to_string(),
}],
refusal: None,
audio: None,
name: None,
tool_calls: vec![],
};
let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
let converted_assistant_message: message::Message =
assistant_message.clone().try_into().unwrap();
match converted_user_message.clone() {
message::Message::User { content } => {
assert_eq!(content.first(), message::UserContent::text("Hello"));
}
_ => panic!("Expected user message"),
}
match converted_assistant_message.clone() {
message::Message::Assistant { content } => {
assert_eq!(
content.first(),
message::AssistantContent::text("Hi there!")
);
}
_ => panic!("Expected assistant message"),
}
let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
let original_assistant_message: Vec<Message> =
converted_assistant_message.try_into().unwrap();
assert_eq!(original_user_message[0], user_message);
assert_eq!(original_assistant_message[0], assistant_message);
}
}


@ -0,0 +1,704 @@
// ================================================================
// OpenAI Completion API
// ================================================================
use super::{ApiErrorResponse, ApiResponse, Client, Usage};
use crate::completion::{CompletionError, CompletionRequest};
use crate::message::{AudioMediaType, ImageDetail};
use crate::one_or_many::string_or_one_or_many;
use crate::{completion, json_utils, message, OneOrMany};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::str::FromStr;
/// `o3-mini` completion model
pub const O3_MINI: &str = "o3-mini";
/// `o3-mini-2025-01-31` completion model
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
/// `o1` completion model
pub const O1: &str = "o1";
/// `o1-2024-12-17` completion model
pub const O1_2024_12_17: &str = "o1-2024-12-17";
/// `o1-preview` completion model
pub const O1_PREVIEW: &str = "o1-preview";
/// `o1-preview-2024-09-12` completion model
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
/// `o1-mini` completion model
pub const O1_MINI: &str = "o1-mini";
/// `o1-mini-2024-09-12` completion model
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
/// `gpt-4.5-preview` completion model
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
/// `gpt-4.5-preview-2025-02-27` completion model
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
/// `gpt-4o` completion model
pub const GPT_4O: &str = "gpt-4o";
/// `gpt-4o-mini` completion model
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
/// `gpt-4o-2024-05-13` completion model
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
/// `gpt-4-turbo` completion model
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
/// `gpt-4-turbo-2024-04-09` completion model
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
/// `gpt-4-turbo-preview` completion model
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
/// `gpt-4-0125-preview` completion model
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
/// `gpt-4-1106-preview` completion model
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
/// `gpt-4-vision-preview` completion model
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
/// `gpt-4-1106-vision-preview` completion model
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
/// `gpt-4` completion model
pub const GPT_4: &str = "gpt-4";
/// `gpt-4-0613` completion model
pub const GPT_4_0613: &str = "gpt-4-0613";
/// `gpt-4-32k` completion model
pub const GPT_4_32K: &str = "gpt-4-32k";
/// `gpt-4-32k-0613` completion model
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
/// `gpt-3.5-turbo` completion model
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// `gpt-3.5-turbo-0125` completion model
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
/// `gpt-3.5-turbo-1106` completion model
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
/// `gpt-3.5-turbo-instruct` completion model
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub system_fingerprint: Option<String>,
pub choices: Vec<Choice>,
pub usage: Option<Usage>,
}
impl From<ApiErrorResponse> for CompletionError {
fn from(err: ApiErrorResponse) -> Self {
CompletionError::ProviderError(err.message)
}
}
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let choice = response.choices.first().ok_or_else(|| {
CompletionError::ResponseError("Response contained no choices".to_owned())
})?;
let content = match &choice.message {
Message::Assistant {
content,
tool_calls,
..
} => {
let mut content = content
.iter()
.filter_map(|c| {
let s = match c {
AssistantContent::Text { text } => text,
AssistantContent::Refusal { refusal } => refusal,
};
if s.is_empty() {
None
} else {
Some(completion::AssistantContent::text(s))
}
})
.collect::<Vec<_>>();
content.extend(
tool_calls
.iter()
.map(|call| {
completion::AssistantContent::tool_call(
&call.id,
&call.function.name,
call.function.arguments.clone(),
)
})
.collect::<Vec<_>>(),
);
Ok(content)
}
_ => Err(CompletionError::ResponseError(
"Response did not contain a valid message or tool call".into(),
)),
}?;
let choice = OneOrMany::many(content).map_err(|_| {
CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)
})?;
Ok(completion::CompletionResponse {
choice,
raw_response: response,
})
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub logprobs: Option<serde_json::Value>,
pub finish_reason: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
System {
#[serde(deserialize_with = "string_or_one_or_many")]
content: OneOrMany<SystemContent>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
#[serde(deserialize_with = "string_or_one_or_many")]
content: OneOrMany<UserContent>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
#[serde(default, deserialize_with = "json_utils::string_or_vec")]
content: Vec<AssistantContent>,
#[serde(skip_serializing_if = "Option::is_none")]
refusal: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
audio: Option<AudioAssistant>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(
default,
deserialize_with = "json_utils::null_or_vec",
skip_serializing_if = "Vec::is_empty"
)]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "tool")]
ToolResult {
tool_call_id: String,
content: OneOrMany<ToolResultContent>,
},
}
impl Message {
pub fn system(content: &str) -> Self {
Message::System {
content: OneOrMany::one(content.to_owned().into()),
name: None,
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct AudioAssistant {
id: String,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SystemContent {
#[serde(default)]
r#type: SystemContentType,
text: String,
}
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SystemContentType {
#[default]
Text,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
Text { text: String },
Refusal { refusal: String },
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
Text {
text: String,
},
#[serde(rename = "image_url")]
Image {
image_url: ImageUrl,
},
Audio {
input_audio: InputAudio,
},
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
pub url: String,
#[serde(default)]
pub detail: ImageDetail,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct InputAudio {
pub data: String,
pub format: AudioMediaType,
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
#[serde(default)]
r#type: ToolResultContentType,
text: String,
}
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
#[default]
Text,
}
impl FromStr for ToolResultContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(s.to_owned().into())
}
}
impl From<String> for ToolResultContent {
fn from(s: String) -> Self {
ToolResultContent {
r#type: ToolResultContentType::default(),
text: s,
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
pub id: String,
#[serde(default)]
pub r#type: ToolType,
pub function: Function,
}
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
#[default]
Function,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
pub r#type: String,
pub function: completion::ToolDefinition,
}
impl From<completion::ToolDefinition> for ToolDefinition {
fn from(tool: completion::ToolDefinition) -> Self {
Self {
r#type: "function".into(),
function: tool,
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
pub name: String,
#[serde(with = "json_utils::stringified_json")]
pub arguments: serde_json::Value,
}
impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;
fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => {
let (tool_results, other_content): (Vec<_>, Vec<_>) = content
.into_iter()
.partition(|content| matches!(content, message::UserContent::ToolResult(_)));
// If there are messages with both tool results and user content, openai will only
// handle tool results. It's unlikely that there will be both.
if !tool_results.is_empty() {
tool_results
.into_iter()
.map(|content| match content {
message::UserContent::ToolResult(message::ToolResult {
id,
content,
}) => Ok::<_, message::MessageError>(Message::ToolResult {
tool_call_id: id,
content: content.try_map(|content| match content {
message::ToolResultContent::Text(message::Text { text }) => {
Ok(text.into())
}
_ => Err(message::MessageError::ConversionError(
"Tool result content does not support non-text".into(),
)),
})?,
}),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()
} else {
let other_content = OneOrMany::many(other_content).expect(
"There must be other content here if there were no tool result content",
);
Ok(vec![Message::User {
content: other_content.map(|content| match content {
message::UserContent::Text(message::Text { text }) => {
UserContent::Text { text }
}
message::UserContent::Image(message::Image {
data, detail, ..
}) => UserContent::Image {
image_url: ImageUrl {
url: data,
detail: detail.unwrap_or_default(),
},
},
message::UserContent::Document(message::Document { data, .. }) => {
UserContent::Text { text: data }
}
message::UserContent::Audio(message::Audio {
data,
media_type,
..
}) => UserContent::Audio {
input_audio: InputAudio {
data,
format: match media_type {
Some(media_type) => media_type,
None => AudioMediaType::MP3,
},
},
},
_ => unreachable!(),
}),
name: None,
}])
}
}
message::Message::Assistant { content } => {
let (text_content, tool_calls) = content.into_iter().fold(
(Vec::new(), Vec::new()),
|(mut texts, mut tools), content| {
match content {
message::AssistantContent::Text(text) => texts.push(text),
message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
}
(texts, tools)
},
);
// `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
// so either `content` or `tool_calls` will have some content.
Ok(vec![Message::Assistant {
content: text_content
.into_iter()
.map(|content| content.text.into())
.collect::<Vec<_>>(),
refusal: None,
audio: None,
name: None,
tool_calls: tool_calls
.into_iter()
.map(|tool_call| tool_call.into())
.collect::<Vec<_>>(),
}])
}
}
}
}
impl From<message::ToolCall> for ToolCall {
fn from(tool_call: message::ToolCall) -> Self {
Self {
id: tool_call.id,
r#type: ToolType::default(),
function: Function {
name: tool_call.function.name,
arguments: tool_call.function.arguments,
},
}
}
}
impl From<ToolCall> for message::ToolCall {
fn from(tool_call: ToolCall) -> Self {
Self {
id: tool_call.id,
function: message::ToolFunction {
name: tool_call.function.name,
arguments: tool_call.function.arguments,
},
}
}
}
impl TryFrom<Message> for message::Message {
type Error = message::MessageError;
fn try_from(message: Message) -> Result<Self, Self::Error> {
Ok(match message {
Message::User { content, .. } => message::Message::User {
content: content.map(|content| content.into()),
},
Message::Assistant {
content,
tool_calls,
..
} => {
let mut content = content
.into_iter()
.map(|content| match content {
AssistantContent::Text { text } => message::AssistantContent::text(text),
// TODO: Currently, refusals are converted into text, but should be
// investigated for generalization.
AssistantContent::Refusal { refusal } => {
message::AssistantContent::text(refusal)
}
})
.collect::<Vec<_>>();
content.extend(
tool_calls
.into_iter()
.map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
.collect::<Result<Vec<_>, _>>()?,
);
message::Message::Assistant {
content: OneOrMany::many(content).map_err(|_| {
message::MessageError::ConversionError(
"Neither `content` nor `tool_calls` was provided to the Message"
.to_owned(),
)
})?,
}
}
Message::ToolResult {
tool_call_id,
content,
} => message::Message::User {
content: OneOrMany::one(message::UserContent::tool_result(
tool_call_id,
content.map(|content| message::ToolResultContent::text(content.text)),
)),
},
// System messages should get stripped out when converting messages; this is just a
// stopgap to avoid obnoxious error handling or a panic occurring.
Message::System { content, .. } => message::Message::User {
content: content.map(|content| message::UserContent::text(content.text)),
},
})
}
}
impl From<UserContent> for message::UserContent {
fn from(content: UserContent) -> Self {
match content {
UserContent::Text { text } => message::UserContent::text(text),
UserContent::Image { image_url } => message::UserContent::image(
image_url.url,
Some(message::ContentFormat::default()),
None,
Some(image_url.detail),
),
UserContent::Audio { input_audio } => message::UserContent::audio(
input_audio.data,
Some(message::ContentFormat::default()),
Some(input_audio.format),
),
}
}
}
impl From<String> for UserContent {
fn from(s: String) -> Self {
UserContent::Text { text: s }
}
}
impl FromStr for UserContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(UserContent::Text {
text: s.to_string(),
})
}
}
impl From<String> for AssistantContent {
fn from(s: String) -> Self {
AssistantContent::Text { text: s }
}
}
impl FromStr for AssistantContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(AssistantContent::Text {
text: s.to_string(),
})
}
}
impl From<String> for SystemContent {
fn from(s: String) -> Self {
SystemContent {
r#type: SystemContentType::default(),
text: s,
}
}
}
impl FromStr for SystemContent {
type Err = Infallible;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(SystemContent {
r#type: SystemContentType::default(),
text: s.to_string(),
})
}
}
#[derive(Clone)]
pub struct CompletionModel {
pub(crate) client: Client,
/// Name of the model (e.g.: gpt-3.5-turbo-1106)
pub model: String,
}
impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
pub(crate) fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<Value, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message::system(preamble)],
None => vec![],
};
// Convert prompt to user message
let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<Message>>, _>>()?
.into_iter()
.flatten()
.collect();
// Combine all messages into a single history
full_history.extend(chat_history);
full_history.extend(prompt);
let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};
// only include temperature if it exists
// because some models don't support temperature
let request = if let Some(temperature) = completion_request.temperature {
json_utils::merge(
request,
json!({
"temperature": temperature,
}),
)
} else {
request
};
let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
};
Ok(request)
}
}
impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let request = self.create_completion_request(completion_request)?;
let response = self
.client
.post("/chat/completions")
.json(&request)
.send()
.await?;
if response.status().is_success() {
let t = response.text().await?;
tracing::debug!(target: "rig", "OpenAI completion response: {}", t);
match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"OpenAI completion token usage: {:?}",
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}


@ -0,0 +1,118 @@
use super::{ApiErrorResponse, ApiResponse, Client, Usage};
use crate::embeddings;
use crate::embeddings::EmbeddingError;
use serde::Deserialize;
use serde_json::json;
// ================================================================
// OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
pub object: String,
pub data: Vec<EmbeddingData>,
pub model: String,
pub usage: Usage,
}
impl From<ApiErrorResponse> for EmbeddingError {
fn from(err: ApiErrorResponse) -> Self {
EmbeddingError::ProviderError(err.message)
}
}
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
match value {
ApiResponse::Ok(response) => Ok(response),
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
}
}
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
pub object: String,
pub embedding: Vec<f64>,
pub index: usize,
}
#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
pub model: String,
ndims: usize,
}
impl embeddings::EmbeddingModel for EmbeddingModel {
const MAX_DOCUMENTS: usize = 1024;
fn ndims(&self) -> usize {
self.ndims
}
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/embeddings")
.json(&json!({
"model": self.model,
"input": documents,
}))
.send()
.await?;
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"OpenAI embedding token usage: {}",
response.usage
);
if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
));
}
Ok(response
.data
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect())
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
impl EmbeddingModel {
pub fn new(client: Client, model: &str, ndims: usize) -> Self {
Self {
client,
model: model.to_string(),
ndims,
}
}
}


@ -0,0 +1,98 @@
use crate::image_generation;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::providers::openai::{ApiResponse, Client};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use serde::Deserialize;
use serde_json::json;
// ================================================================
// OpenAI Image Generation API
// ================================================================
pub const DALL_E_2: &str = "dall-e-2";
pub const DALL_E_3: &str = "dall-e-3";
#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
pub b64_json: String,
}
#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
pub created: i32,
pub data: Vec<ImageGenerationData>,
}
impl TryFrom<ImageGenerationResponse>
for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
type Error = ImageGenerationError;
fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
let b64_json = value.data[0].b64_json.clone();
let bytes = BASE64_STANDARD
.decode(&b64_json)
.expect("Failed to decode b64");
Ok(image_generation::ImageGenerationResponse {
image: bytes,
response: value,
})
}
}
#[derive(Clone)]
pub struct ImageGenerationModel {
client: Client,
/// Name of the model (e.g.: dall-e-2)
pub model: String,
}
impl ImageGenerationModel {
pub(crate) fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl image_generation::ImageGenerationModel for ImageGenerationModel {
type Response = ImageGenerationResponse;
async fn image_generation(
&self,
generation_request: ImageGenerationRequest,
) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
{
let request = json!({
"model": self.model,
"prompt": generation_request.prompt,
"size": format!("{}x{}", generation_request.width, generation_request.height),
"response_format": "b64_json"
});
let response = self
.client
.post("/images/generations")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(ImageGenerationError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
let t = response.text().await?;
match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
}
}
}


@ -0,0 +1,27 @@
//! OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openai;
//!
//! let client = openai::Client::new("YOUR_API_KEY");
//!
//! let gpt4o = client.completion_model(openai::GPT_4O);
//! ```
pub mod client;
pub mod completion;
pub mod embedding;
#[cfg(feature = "image")]
pub mod image_generation;
pub mod streaming;
pub mod transcription;
pub use client::*;
pub use completion::*;
pub use embedding::*;
#[cfg(feature = "image")]
pub use image_generation::*;
pub use streaming::*;
pub use transcription::*;
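
A hedged note on the feature gate: the image_generation module and its re-exports only exist when rig is compiled with the `image` feature (e.g. `cargo build --features image` in this repo). A tiny sketch, as it would appear in code inside this crate; the alias names are illustrative:

```rust
// Both paths resolve to the same constant thanks to the re-export above;
// neither exists unless the `image` feature is enabled.
#[cfg(feature = "image")]
use rig::providers::openai::{image_generation::DALL_E_3 as VIA_MODULE, DALL_E_3 as VIA_REEXPORT};

#[cfg(feature = "image")]
fn assert_same_reexport() {
    assert_eq!(VIA_MODULE, VIA_REEXPORT);
}
```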

View File

@ -0,0 +1,184 @@
use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge;
use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
// ================================================================
// OpenAI Completion Streaming API
// ================================================================
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
#[serde(default)]
name: Option<String>,
#[serde(default)]
arguments: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
pub index: usize,
pub function: StreamingFunction,
}
#[derive(Deserialize)]
struct StreamingDelta {
#[serde(default)]
content: Option<String>,
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<StreamingToolCall>,
}
#[derive(Deserialize)]
struct StreamingChoice {
delta: StreamingDelta,
}
#[derive(Deserialize)]
struct StreamingCompletionResponse {
choices: Vec<StreamingChoice>,
}
impl StreamingCompletionModel for CompletionModel {
async fn stream(
&self,
completion_request: CompletionRequest,
) -> Result<StreamingResult, CompletionError> {
let mut request = self.create_completion_request(completion_request)?;
request = merge(request, json!({"stream": true}));
let builder = self.client.post("/chat/completions").json(&request);
send_compatible_streaming_request(builder).await
}
}
pub async fn send_compatible_streaming_request(
request_builder: RequestBuilder,
) -> Result<StreamingResult, CompletionError> {
let response = request_builder.send().await?;
if !response.status().is_success() {
return Err(CompletionError::ProviderError(format!(
"{}: {}",
response.status(),
response.text().await?
)));
}
// Handle OpenAI Compatible SSE chunks
Ok(Box::pin(stream! {
let mut stream = response.bytes_stream();
let mut partial_data = None;
let mut calls: HashMap<usize, (String, String)> = HashMap::new();
while let Some(chunk_result) = stream.next().await {
let chunk = match chunk_result {
Ok(c) => c,
Err(e) => {
yield Err(CompletionError::from(e));
break;
}
};
let text = match String::from_utf8(chunk.to_vec()) {
Ok(t) => t,
Err(e) => {
yield Err(CompletionError::ResponseError(e.to_string()));
break;
}
};
for line in text.lines() {
let mut line = line.to_string();
// If there was a remaining part, concat with current line
if partial_data.is_some() {
line = format!("{}{}", partial_data.unwrap(), line);
partial_data = None;
}
// Otherwise full data line
else {
let Some(data) = line.strip_prefix("data: ") else {
continue;
};
// Partial data, split somewhere in the middle
if !line.ends_with("}") {
partial_data = Some(data.to_string());
} else {
line = data.to_string();
}
}
let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
let Ok(data) = data else {
continue;
};
            let Some(choice) = data.choices.first() else {
                continue;
            };
let delta = &choice.delta;
if !delta.tool_calls.is_empty() {
for tool_call in &delta.tool_calls {
let function = tool_call.function.clone();
// Start of tool call
// name: Some(String)
// arguments: None
if function.name.is_some() && function.arguments.is_empty() {
calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
}
// Part of tool call
// name: None
// arguments: Some(String)
else if function.name.is_none() && !function.arguments.is_empty() {
let Some((name, arguments)) = calls.get(&tool_call.index) else {
continue;
};
let new_arguments = &tool_call.function.arguments;
let arguments = format!("{}{}", arguments, new_arguments);
calls.insert(tool_call.index, (name.clone(), arguments));
}
// Entire tool call
else {
let name = function.name.unwrap();
let arguments = function.arguments;
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
}
}
}
if let Some(content) = &choice.delta.content {
yield Ok(streaming::StreamingChoice::Message(content.clone()))
}
}
}
for (_, (name, arguments)) in calls {
let Ok(arguments) = serde_json::from_str(&arguments) else {
continue;
};
yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
}
}))
}
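
To make the SSE handling above concrete, here is a minimal, self-contained sketch (with made-up chunk payloads) of the partial-line reassembly that send_compatible_streaming_request performs before deserializing each `data:` line:

```rust
fn main() {
    // Two chunks as they might arrive off the wire; the second completes the
    // JSON object started in the first (payloads are illustrative).
    let chunks = [
        r#"data: {"choices":[{"delta":{"content":"Hel"#,
        r#"lo"}}]}"#,
    ];

    // Mirrors the loop above: strip the "data: " prefix, and if a line does not
    // end in '}', stash it and prepend it to the next line before parsing.
    let mut partial: Option<String> = None;
    for raw in chunks {
        let line = match partial.take() {
            Some(prev) => format!("{prev}{raw}"),
            None => {
                let Some(data) = raw.strip_prefix("data: ") else { continue };
                if !raw.ends_with('}') {
                    partial = Some(data.to_string());
                    continue;
                }
                data.to_string()
            }
        };

        let value: serde_json::Value = serde_json::from_str(&line).expect("complete JSON");
        assert_eq!(value["choices"][0]["delta"]["content"], "Hello");
    }
}
```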

View File

@ -0,0 +1,105 @@
use crate::providers::openai::{ApiResponse, Client};
use crate::transcription;
use crate::transcription::TranscriptionError;
use reqwest::multipart::Part;
use serde::Deserialize;
// ================================================================
// OpenAI Transcription API
// ================================================================
pub const WHISPER_1: &str = "whisper-1";
#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
pub text: String,
}
impl TryFrom<TranscriptionResponse>
for transcription::TranscriptionResponse<TranscriptionResponse>
{
type Error = TranscriptionError;
fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
Ok(transcription::TranscriptionResponse {
text: value.text.clone(),
response: value,
})
}
}
#[derive(Clone)]
pub struct TranscriptionModel {
client: Client,
    /// Name of the model (e.g.: whisper-1)
pub model: String,
}
impl TranscriptionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}
impl transcription::TranscriptionModel for TranscriptionModel {
type Response = TranscriptionResponse;
#[cfg_attr(feature = "worker", worker::send)]
async fn transcription(
&self,
request: transcription::TranscriptionRequest,
) -> Result<
transcription::TranscriptionResponse<Self::Response>,
transcription::TranscriptionError,
> {
let data = request.data;
let mut body = reqwest::multipart::Form::new()
.text("model", self.model.clone())
.text("language", request.language)
.part(
"file",
Part::bytes(data).file_name(request.filename.clone()),
);
if let Some(prompt) = request.prompt {
body = body.text("prompt", prompt.clone());
}
if let Some(ref temperature) = request.temperature {
body = body.text("temperature", temperature.to_string());
}
if let Some(ref additional_params) = request.additional_params {
for (key, value) in additional_params
.as_object()
.expect("Additional Parameters to OpenAI Transcription should be a map")
{
body = body.text(key.to_owned(), value.to_string());
}
}
let response = self
.client
.post("audio/transcriptions")
.multipart(body)
.send()
.await?;
if response.status().is_success() {
match response
.json::<ApiResponse<TranscriptionResponse>>()
.await?
{
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
api_error_response.message,
)),
}
} else {
Err(TranscriptionError::ProviderError(response.text().await?))
}
}
}
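
For reference, a freestanding sketch of the multipart form that transcription() above assembles; the file name, language, and optional fields are illustrative, and it assumes reqwest is built with its multipart feature (which this code already relies on):

```rust
use reqwest::multipart::{Form, Part};

// Illustrative values only; mirrors the form built by transcription() above.
fn build_form(audio_bytes: Vec<u8>) -> Form {
    Form::new()
        .text("model", "whisper-1")
        .text("language", "en")
        .part("file", Part::bytes(audio_bytes).file_name("sample.wav"))
        // Optional fields, attached only when present on the request:
        .text("prompt", "Names: Rig, OpenAI")
        .text("temperature", "0.2")
}

fn main() {
    let form = build_form(vec![0u8; 16]); // stand-in for real audio bytes
    // In the real model this form is sent via the client's multipart POST.
    drop(form);
}
```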