mirror of https://github.com/0xplaygrounds/rig
feat: Add image generation to all providers that support it (#357)
* feat: add transcription to all providers that support it
* feat: gemini streaming
* chore: fix clippy warnings
* feat: together streaming
* refactor: move transcription preamble to constant
* chore: fix clippy
* fix: enable blob by default
* fix: lock methods and imports behind feature flag
* chore: fmt
* feat: add huggingface transcription, remove feature flag
* feat: streaming for most openai type providers
* feat: add streaming to remaining providers
* chore: formatting
* fix: gemini streaming wasn't added to mod
* chore: unused imports in example
* chore: unused import
* chore: run fmt
* feat: openai & azure image generation
* chore: reorganize a bit
* refactor: simplify oai-compatible streaming
* feat: huggingface image generation
* chore: run fmt & clippy
* feat: hyperbolic image generation
* refactor: break up openai
* fix: lock changes behind feature flag
* chore: remove unused import
* refactor: change size to width/height, update examples
* fix: error with examples
* chore: remove unused imports
* fix: feature in lib not gated behind flag
* fix: used wrong flag
* fix: import not behind flag
* fix: more flag issues

---------

Co-authored-by: yavens <179155341+yavens@users.github.noreply.github.com>
This commit is contained in:
parent ab8cb89948
commit 70855257c4
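At a glance, the feature adds a shared request-builder API that each supporting provider exposes. A minimal usage sketch, condensed from the examples added below (assumes the `image` cargo feature is enabled and `OPENAI_API_KEY` is set):

    use rig::image_generation::ImageGenerationModel;
    use rig::providers::openai;

    #[tokio::main]
    async fn main() {
        let openai = openai::Client::from_env();
        let dalle = openai.image_generation_model(openai::DALL_E_2);

        // Build, customize, and send an image generation request.
        let response = dalle
            .image_generation_request()
            .prompt("A watercolor fox in a snowy forest")
            .width(1024)
            .height(1024)
            .send()
            .await
            .expect("Failed to generate image");

        // `response.image` holds the returned image bytes.
        std::fs::write("./output.png", &response.image).expect("Failed to write image");
    }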
@@ -48,6 +48,7 @@ serde_path_to_error = "0.1.16"
[features]
default = ["reqwest/default"]
all = ["derive", "pdf", "rayon"]
+image = []
derive = ["dep:rig-derive"]
pdf = ["dep:lopdf"]
epub = ["dep:epub", "dep:quick-xml"]
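Image generation is gated behind the new `image` feature, so downstream crates opt in explicitly. A sketch of a consumer's Cargo.toml (the version is a placeholder, not taken from this commit):

    [dependencies]
    # "*" is a placeholder version; enable the `image` feature to compile the
    # image_generation module and the providers' image generation support.
    rig-core = { version = "*", features = ["image"] }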
@@ -1,4 +1,4 @@
-use rig::providers::openai::Client;
+use rig::providers::openai::client::Client;
use schemars::JsonSchema;
use std::env;
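This and the next few example updates track the `refactor: break up openai` entry above: the `openai` provider is now split into submodules, so `Client` is imported through its `client` module. A before/after sketch:

    // Before this commit the provider was a single flat module:
    // use rig::providers::openai::Client;

    // After the split, `Client` lives in the `client` submodule:
    use rig::providers::openai::client::Client;

    fn main() {
        // Construction is unchanged; only the import path moved.
        let _openai = Client::new("your-open-ai-api-key");
    }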
@@ -1,6 +1,7 @@
use std::env;

-use rig::{completion::Prompt, providers::openai::Client};
+use rig::completion::Prompt;
+use rig::providers::openai::client::Client;
use schemars::JsonSchema;

#[derive(serde::Deserialize, JsonSchema, serde::Serialize, Debug)]
@@ -1,6 +1,6 @@
use std::env;

-use rig::providers::openai::Client;
+use rig::providers::openai::client::Client;

use schemars::JsonSchema;
@@ -1,10 +1,10 @@
use std::env;

use rig::pipeline::agent_ops::extract;
+use rig::providers::openai::client::Client;
use rig::{
    parallel,
    pipeline::{self, passthrough, Op},
-    providers::openai::Client,
};
use schemars::JsonSchema;
@@ -1,9 +1,7 @@
use std::env;

-use rig::{
-    pipeline::{self, Op},
-    providers::openai::Client,
-};
+use rig::pipeline::{self, Op};
+use rig::providers::openai::client::Client;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
@@ -1,9 +1,7 @@
use std::env;

-use rig::{
-    pipeline::{self, Op, TryOp},
-    providers::openai::Client,
-};
+use rig::pipeline::{self, Op, TryOp};
+use rig::providers::openai::client::Client;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
@@ -0,0 +1,38 @@
+use rig::image_generation::ImageGenerationModel;
+use rig::providers::huggingface;
+use std::env::args;
+use std::fs::File;
+use std::io::Write;
+use std::path::Path;
+
+const DEFAULT_PATH: &str = "./output.png";
+
+#[tokio::main]
+async fn main() {
+    let arguments: Vec<String> = args().collect();
+
+    let path = if arguments.len() > 1 {
+        arguments[1].clone()
+    } else {
+        DEFAULT_PATH.to_string()
+    };
+
+    let path = Path::new(&path);
+
+    let mut file = File::create_new(path).expect("Failed to create file");
+
+    let huggingface = huggingface::Client::from_env();
+
+    let stable_diffusion = huggingface.image_generation_model(huggingface::STABLE_DIFFUSION_3);
+
+    let response = stable_diffusion
+        .image_generation_request()
+        .prompt("A castle sitting upon a large mountain, overlooking the water.")
+        .width(1024)
+        .height(1024)
+        .send()
+        .await
+        .expect("Failed to generate image");
+
+    file.write_all(&response.image)
+        .expect("Failed to write image");
+}
@@ -0,0 +1,37 @@
+use rig::image_generation::ImageGenerationModel;
+use rig::providers::hyperbolic;
+use std::env::args;
+use std::fs::File;
+use std::io::Write;
+use std::path::Path;
+
+const DEFAULT_PATH: &str = "./output.png";
+
+#[tokio::main]
+async fn main() {
+    let arguments: Vec<String> = args().collect();
+
+    let path = if arguments.len() > 1 {
+        arguments[1].clone()
+    } else {
+        DEFAULT_PATH.to_string()
+    };
+
+    let path = Path::new(&path);
+    let mut file = File::create_new(path).expect("Failed to create file");
+
+    let hyperbolic = hyperbolic::Client::from_env();
+
+    let stable_diffusion = hyperbolic.image_generation_model(hyperbolic::SDXL_TURBO);
+
+    let response = stable_diffusion
+        .image_generation_request()
+        .prompt("A castle sitting upon a large mountain, overlooking the water.")
+        .width(1024)
+        .height(1024)
+        .send()
+        .await
+        .expect("Failed to generate image");
+
+    file.write_all(&response.image)
+        .expect("Failed to write image");
+}
@@ -0,0 +1,37 @@
+use rig::image_generation::ImageGenerationModel;
+use rig::providers::openai;
+use std::env::args;
+use std::fs::File;
+use std::io::Write;
+use std::path::Path;
+
+const DEFAULT_PATH: &str = "./output.png";
+
+#[tokio::main]
+async fn main() {
+    let arguments: Vec<String> = args().collect();
+
+    let path = if arguments.len() > 1 {
+        arguments[1].clone()
+    } else {
+        DEFAULT_PATH.to_string()
+    };
+
+    let path = Path::new(&path);
+    let mut file = File::create_new(path).expect("Failed to create file");
+
+    let openai = openai::Client::from_env();
+
+    let dalle = openai.image_generation_model(openai::DALL_E_2);
+
+    let response = dalle
+        .image_generation_request()
+        .prompt("A castle sitting upon a large mountain, overlooking the water.")
+        .width(1024)
+        .height(1024)
+        .send()
+        .await
+        .expect("Failed to generate image");
+
+    file.write_all(&response.image)
+        .expect("Failed to write image");
+}
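All three examples stick to the shared builder methods; provider-specific options go through `additional_params`, which providers merge into the outgoing request body (Hyperbolic does this with `merge_inplace` later in this diff). A hedged sketch; the "negative_prompt" key is a hypothetical provider option, not one confirmed by this commit:

    use rig::image_generation::ImageGenerationModel;
    use rig::providers::hyperbolic;
    use serde_json::json;

    #[tokio::main]
    async fn main() {
        let hyperbolic = hyperbolic::Client::from_env();
        let model = hyperbolic.image_generation_model(hyperbolic::SDXL_TURBO);

        let response = model
            .image_generation_request()
            .prompt("A lighthouse at dusk")
            // Extra fields are merged into the provider's request body;
            // "negative_prompt" is an illustrative, hypothetical key.
            .additional_params(json!({ "negative_prompt": "blurry, low quality" }))
            .send()
            .await
            .expect("Failed to generate image");

        println!("generated {} bytes", response.image.len());
    }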
@@ -1,11 +1,9 @@
use std::{env, vec};

+use rig::providers::openai::client::Client;
use rig::{
-    completion::Prompt,
-    embeddings::EmbeddingsBuilder,
-    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
-    vector_store::in_memory_store::InMemoryVectorStore,
-    Embed,
+    completion::Prompt, embeddings::EmbeddingsBuilder, providers::openai::TEXT_EMBEDDING_ADA_002,
+    vector_store::in_memory_store::InMemoryVectorStore, Embed,
};
use serde::Serialize;
@@ -1,8 +1,9 @@
use std::env;

+use rig::providers::openai::client::Client;
use rig::{
    embeddings::EmbeddingsBuilder,
-    providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
+    providers::openai::TEXT_EMBEDDING_ADA_002,
    vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex},
    Embed,
};
@@ -0,0 +1,128 @@
+use serde_json::Value;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum ImageGenerationError {
+    /// Http error (e.g.: connection error, timeout, etc.)
+    #[error("HttpError: {0}")]
+    HttpError(#[from] reqwest::Error),
+
+    /// Json error (e.g.: serialization, deserialization)
+    #[error("JsonError: {0}")]
+    JsonError(#[from] serde_json::Error),
+
+    /// Error building the image generation request
+    #[error("RequestError: {0}")]
+    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
+
+    /// Error parsing the image generation response
+    #[error("ResponseError: {0}")]
+    ResponseError(String),
+
+    /// Error returned by the image generation model provider
+    #[error("ProviderError: {0}")]
+    ProviderError(String),
+}
+
+pub trait ImageGeneration<M: ImageGenerationModel> {
+    /// Generates an image generation request builder for the given `prompt` and `size`.
+    /// This function is meant to be called by the user to further customize the
+    /// request at generation time before sending it.
+    ///
+    /// ❗IMPORTANT: The type that implements this trait might have already
+    /// populated fields in the builder (the exact fields depend on the type).
+    /// For fields that have already been set by the model, calling the corresponding
+    /// method on the builder will overwrite the value set by the model.
+    fn image_generation(
+        &self,
+        prompt: &str,
+        size: &(u32, u32),
+    ) -> impl std::future::Future<
+        Output = Result<ImageGenerationRequestBuilder<M>, ImageGenerationError>,
+    > + Send;
+}
+
+pub struct ImageGenerationResponse<T> {
+    pub image: Vec<u8>,
+    pub response: T,
+}
+
+pub trait ImageGenerationModel: Clone + Send + Sync {
+    type Response: Send + Sync;
+
+    fn image_generation(
+        &self,
+        request: ImageGenerationRequest,
+    ) -> impl std::future::Future<
+        Output = Result<ImageGenerationResponse<Self::Response>, ImageGenerationError>,
+    > + Send;
+
+    fn image_generation_request(&self) -> ImageGenerationRequestBuilder<Self> {
+        ImageGenerationRequestBuilder::new(self.clone())
+    }
+}
+
+pub struct ImageGenerationRequest {
+    pub prompt: String,
+    pub width: u32,
+    pub height: u32,
+    pub additional_params: Option<Value>,
+}
+
+pub struct ImageGenerationRequestBuilder<M: ImageGenerationModel> {
+    model: M,
+    prompt: String,
+    width: u32,
+    height: u32,
+    additional_params: Option<Value>,
+}
+
+impl<M: ImageGenerationModel> ImageGenerationRequestBuilder<M> {
+    pub fn new(model: M) -> Self {
+        Self {
+            model,
+            prompt: "".to_string(),
+            height: 256,
+            width: 256,
+            additional_params: None,
+        }
+    }
+
+    /// Sets the prompt for the image generation request
+    pub fn prompt(mut self, prompt: &str) -> Self {
+        self.prompt = prompt.to_string();
+        self
+    }
+
+    /// The width of the generated image
+    pub fn width(mut self, width: u32) -> Self {
+        self.width = width;
+        self
+    }
+
+    /// The height of the generated image
+    pub fn height(mut self, height: u32) -> Self {
+        self.height = height;
+        self
+    }
+
+    /// Adds additional parameters to the image generation request.
+    pub fn additional_params(mut self, params: Value) -> Self {
+        self.additional_params = Some(params);
+        self
+    }
+
+    pub fn build(self) -> ImageGenerationRequest {
+        ImageGenerationRequest {
+            prompt: self.prompt,
+            width: self.width,
+            height: self.height,
+            additional_params: self.additional_params,
+        }
+    }
+
+    pub async fn send(self) -> Result<ImageGenerationResponse<M::Response>, ImageGenerationError> {
+        let model = self.model.clone();

+        model.image_generation(self.build()).await
+    }
+}
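Everything a provider needs is the trait surface above: decode bytes into `image`, keep the raw payload in `response`, and the default `image_generation_request` builder method comes for free. A minimal sketch of a custom backend (`MockModel` and its canned output are hypothetical, purely to show the shape of an implementation):

    use rig::image_generation::{self, ImageGenerationError, ImageGenerationRequest};

    // Hypothetical in-memory backend used to illustrate the trait shape.
    #[derive(Clone)]
    struct MockModel;

    impl image_generation::ImageGenerationModel for MockModel {
        // Real providers keep their raw API response here.
        type Response = ();

        async fn image_generation(
            &self,
            request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            if request.prompt.is_empty() {
                return Err(ImageGenerationError::ProviderError("empty prompt".into()));
            }
            // Canned output: one byte per pixel, standing in for real image data.
            Ok(image_generation::ImageGenerationResponse {
                image: vec![0u8; (request.width * request.height) as usize],
                response: (),
            })
        }
    }

With this in place, `MockModel.image_generation_request().prompt("...").send().await` behaves exactly like the real providers.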
@@ -85,6 +85,8 @@ pub mod cli_chatbot;
pub mod completion;
pub mod embeddings;
pub mod extractor;
+#[cfg(feature = "image")]
+pub mod image_generation;
pub(crate) mod json_utils;
pub mod loaders;
pub mod one_or_many;
@@ -10,6 +10,12 @@
//! ```

use super::openai::{send_compatible_streaming_request, TranscriptionResponse};
+
+#[cfg(feature = "image")]
+use super::openai::ImageGenerationResponse;
+
+#[cfg(feature = "image")]
+use crate::image_generation::{self, ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@@ -157,6 +163,16 @@ impl Client {
        self.http_client.post(url)
    }

+    #[cfg(feature = "image")]
+    fn post_image_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder {
+        let url = format!(
+            "{}/openai/deployments/{}/images/generations?api-version={}",
+            self.azure_endpoint, deployment_id, self.api_version
+        )
+        .replace("//", "/");
+        self.http_client.post(url)
+    }
+
    /// Create an embedding model with the given name.
    /// Note: default embedding dimension of 0 will be used if model is not known.
    /// If this is the case, it's better to use function `embedding_model_with_ndims`
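For reference, the route this helper builds follows the Azure OpenAI deployments convention; with a hypothetical deployment id and API version it has the shape:

    {azure_endpoint}/openai/deployments/dalle-3/images/generations?api-version=2024-02-01

(`dalle-3` and `2024-02-01` are illustrative placeholders, not values from this commit.)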
@@ -642,6 +658,55 @@ impl transcription::TranscriptionModel for TranscriptionModel {
    }
}

+// ================================================================
+// Azure OpenAI Image Generation API
+// ================================================================
+#[cfg(feature = "image")]
+#[derive(Clone)]
+pub struct ImageGenerationModel {
+    client: Client,
+    pub model: String,
+}
+
+#[cfg(feature = "image")]
+impl image_generation::ImageGenerationModel for ImageGenerationModel {
+    type Response = ImageGenerationResponse;
+
+    async fn image_generation(
+        &self,
+        generation_request: ImageGenerationRequest,
+    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
+    {
+        let request = json!({
+            "model": self.model,
+            "prompt": generation_request.prompt,
+            "size": format!("{}x{}", generation_request.width, generation_request.height),
+            "response_format": "b64_json"
+        });
+
+        let response = self
+            .client
+            .post_image_generation(&self.model)
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            return Err(ImageGenerationError::ProviderError(format!(
+                "{}: {}",
+                response.status(),
+                response.text().await?
+            )));
+        }
+
+        let t = response.text().await?;
+
+        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
+            ApiResponse::Ok(response) => response.try_into(),
+            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
+        }
+    }
+}
+
#[cfg(test)]
mod azure_tests {
    use super::*;
@@ -2,7 +2,12 @@ use std::fmt::Display;

use super::completion::CompletionModel;
use crate::agent::AgentBuilder;
+#[cfg(feature = "image")]
+use crate::image_generation::ImageGenerationError;
+#[cfg(feature = "image")]
+use crate::providers::huggingface::image_generation::ImageGenerationModel;
use crate::providers::huggingface::transcription::TranscriptionModel;
+use crate::transcription::TranscriptionError;

// ================================================================
// Main Huggingface Client
@@ -36,10 +41,27 @@ impl SubProvider {
    /// Get the transcription endpoint for the SubProvider
    /// Required because Huggingface Inference requires the model
    /// in the url and in the request body.
-    pub fn transcription_endpoint(&self, model: &str) -> String {
+    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
        match self {
-            SubProvider::HFInference => format!("hf-inference/models/{}", model),
-            _ => panic!("transcription endpoint is not supported yet for {}", self),
+            SubProvider::HFInference => Ok(format!("/{}", model)),
+            _ => Err(TranscriptionError::ProviderError(format!(
+                "transcription endpoint is not supported yet for {}",
+                self
+            ))),
        }
    }

+    /// Get the image generation endpoint for the SubProvider
+    /// Required because Huggingface Inference requires the model
+    /// in the url and in the request body.
+    #[cfg(feature = "image")]
+    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
+        match self {
+            SubProvider::HFInference => Ok(format!("/{}", model)),
+            _ => Err(ImageGenerationError::ProviderError(format!(
+                "image generation endpoint is not supported yet for {}",
+                self
+            ))),
+        }
+    }
+
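Returning `Result` here turns an unsupported sub-provider into a recoverable error instead of a panic. A sketch of the calling pattern (written as crate-internal code, since the diff only shows `SubProvider` used inside the provider; the helper function itself is hypothetical):

    // Callers now match on the Result instead of risking a panic.
    fn describe_route(sub_provider: &SubProvider, model: &str) {
        match sub_provider.image_generation_endpoint(model) {
            Ok(route) => println!("POST {route}"),
            Err(err) => eprintln!("image generation unavailable: {err}"),
        }
    }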
@@ -124,7 +146,9 @@ pub struct Client {
impl Client {
    /// Create a new Huggingface client with the given API key.
    pub fn new(api_key: &str) -> Self {
-        Self::from_url(api_key, HUGGINGFACE_API_BASE_URL, SubProvider::HFInference)
+        let base_url =
+            format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
+        Self::from_url(api_key, &base_url, SubProvider::HFInference)
    }

    /// Create a new Client with the given API key and base API URL.
@@ -193,10 +217,27 @@ impl Client {
    ///
    /// let transcription_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
    /// ```
    ///
    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }

+    /// Create a new image generation model with the given name
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::huggingface::{self, Client};
+    ///
+    /// // Initialize the Huggingface client
+    /// let client = Client::new("your-huggingface-api-key");
+    ///
+    /// let image_model = client.image_generation_model(huggingface::STABLE_DIFFUSION_3);
+    /// ```
+    #[cfg(feature = "image")]
+    pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
+        ImageGenerationModel::new(self.clone(), model)
+    }
+
    /// Create an agent builder with the given completion model.
    ///
    /// # Example
@@ -0,0 +1,78 @@
+use super::Client;
+use crate::image_generation;
+use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
+use serde_json::json;
+
+pub const FLUX_1: &str = "black-forest-labs/FLUX.1-dev";
+pub const KOLORS: &str = "Kwai-Kolors/Kolors";
+pub const STABLE_DIFFUSION_3: &str = "stabilityai/stable-diffusion-3-medium-diffusers";
+
+#[derive(Debug)]
+pub struct ImageGenerationResponse {
+    data: Vec<u8>,
+}
+
+impl TryFrom<ImageGenerationResponse>
+    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
+{
+    type Error = ImageGenerationError;
+
+    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
+        Ok(image_generation::ImageGenerationResponse {
+            image: value.data.clone(),
+            response: value,
+        })
+    }
+}
+
+#[derive(Clone)]
+pub struct ImageGenerationModel {
+    client: Client,
+    pub model: String,
+}
+
+impl ImageGenerationModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        ImageGenerationModel {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl image_generation::ImageGenerationModel for ImageGenerationModel {
+    type Response = ImageGenerationResponse;
+
+    async fn image_generation(
+        &self,
+        request: ImageGenerationRequest,
+    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
+    {
+        let request = json!({
+            "inputs": request.prompt,
+            "parameters": {
+                "width": request.width,
+                "height": request.height
+            }
+        });
+
+        let route = self
+            .client
+            .sub_provider
+            .image_generation_endpoint(&self.model)?;
+
+        let response = self.client.post(&route).json(&request).send().await?;
+
+        if !response.status().is_success() {
+            return Err(ImageGenerationError::ProviderError(format!(
+                "{}: {}",
+                response.status(),
+                response.text().await?
+            )));
+        }
+
+        let data = response.bytes().await?.to_vec();
+
+        ImageGenerationResponse { data }.try_into()
+    }
+}
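Note the asymmetry between providers: Azure OpenAI requests `"response_format": "b64_json"` and Hyperbolic base64-decodes a JSON field (both elsewhere in this diff), whereas Hugging Face Inference responds with raw image bytes, so `response.bytes()` feeds `image` directly. A usage sketch relying on that (assumes the `image` feature and the environment variable read by `Client::from_env`):

    use rig::image_generation::ImageGenerationModel;
    use rig::providers::huggingface;

    #[tokio::main]
    async fn main() {
        let client = huggingface::Client::from_env();
        let model = client.image_generation_model(huggingface::FLUX_1);

        let response = model
            .image_generation_request()
            .prompt("An isometric pixel-art city block")
            .send()
            .await
            .expect("Failed to generate image");

        // No base64 step needed: `image` already holds the raw bytes
        // exactly as returned by the Inference endpoint.
        std::fs::write("./city.png", &response.image).expect("Failed to write image");
    }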
@@ -12,6 +12,9 @@

pub mod client;
pub mod completion;
+
+#[cfg(feature = "image")]
+pub mod image_generation;
pub mod streaming;
pub mod transcription;
@@ -20,4 +23,7 @@ pub use completion::{
    GEMMA_2, META_LLAMA_3_1, PHI_4, QWEN2_5, QWEN2_5_CODER, QWEN2_VL, QWEN_QVQ_PREVIEW,
    SMALLTHINKER_PREVIEW,
};
+
+#[cfg(feature = "image")]
+pub use image_generation::{FLUX_1, KOLORS, STABLE_DIFFUSION_3};
pub use transcription::{WHISPER_LARGE_V3, WHISPER_LARGE_V3_TURBO, WHISPER_SMALL};
@@ -59,7 +59,10 @@ impl transcription::TranscriptionModel for TranscriptionModel {
            "inputs": data
        });

-        let route = self.client.sub_provider.transcription_endpoint(&self.model);
+        let route = self
+            .client
+            .sub_provider
+            .transcription_endpoint(&self.model)?;
        let response = self.client.post(&route).json(&request).send().await?;

        if response.status().is_success() {
@@ -10,6 +10,9 @@
//! ```

use super::openai::{send_compatible_streaming_request, AssistantContent};
+
+#[cfg(feature = "image")]
+use crate::image_generation::{self, ImageGenerationError, ImageGenerationRequest};
use crate::json_utils::merge_inplace;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use crate::{
@@ -20,6 +23,11 @@ use crate::{
    providers::openai::Message,
    OneOrMany,
};
+
+#[cfg(feature = "image")]
+use base64::prelude::BASE64_STANDARD;
+#[cfg(feature = "image")]
+use base64::Engine;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
@@ -88,6 +96,22 @@ impl Client {
        CompletionModel::new(self.clone(), model)
    }

+    /// Create an image generation model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::hyperbolic::{self, Client};
+    ///
+    /// // Initialize the Hyperbolic client
+    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
+    ///
+    /// let sdxl_turbo = hyperbolic.image_generation_model(hyperbolic::SDXL_TURBO);
+    /// ```
+    #[cfg(feature = "image")]
+    pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
+        ImageGenerationModel::new(self.clone(), model)
+    }
+
    /// Create an agent builder with the given completion model.
    ///
    /// # Example
@@ -369,3 +393,106 @@ impl StreamingCompletionModel for CompletionModel {
        send_compatible_streaming_request(builder).await
    }
}
+
+// =======================================
+// Hyperbolic Image Generation API
+// =======================================
+pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
+pub const SD2: &str = "SD2";
+pub const SD1_5: &str = "SD1.5";
+pub const SSD: &str = "SSD";
+pub const SDXL_TURBO: &str = "SDXL-turbo";
+pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
+pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";
+
+#[cfg(feature = "image")]
+#[derive(Clone)]
+pub struct ImageGenerationModel {
+    client: Client,
+    pub model: String,
+}
+
+#[cfg(feature = "image")]
+impl ImageGenerationModel {
+    fn new(client: Client, model: &str) -> ImageGenerationModel {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+#[cfg(feature = "image")]
+#[derive(Clone, Deserialize)]
+pub struct Image {
+    image: String,
+}
+
+#[cfg(feature = "image")]
+#[derive(Clone, Deserialize)]
+pub struct ImageGenerationResponse {
+    images: Vec<Image>,
+}
+
+#[cfg(feature = "image")]
+impl TryFrom<ImageGenerationResponse>
+    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
+{
+    type Error = ImageGenerationError;
+
+    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
+        let data = BASE64_STANDARD
+            .decode(&value.images[0].image)
+            .expect("Could not decode image.");
+
+        Ok(Self {
+            image: data,
+            response: value,
+        })
+    }
+}
+
+#[cfg(feature = "image")]
+impl image_generation::ImageGenerationModel for ImageGenerationModel {
+    type Response = ImageGenerationResponse;
+
+    async fn image_generation(
+        &self,
+        generation_request: ImageGenerationRequest,
+    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
+    {
+        let mut request = json!({
+            "model_name": self.model,
+            "prompt": generation_request.prompt,
+            "height": generation_request.height,
+            "width": generation_request.width,
+        });
+
+        if let Some(params) = generation_request.additional_params {
+            merge_inplace(&mut request, params);
+        }
+
+        let response = self
+            .client
+            .post("/image/generation")
+            .json(&request)
+            .send()
+            .await?;
+
+        if !response.status().is_success() {
+            return Err(ImageGenerationError::ProviderError(format!(
+                "{}: {}",
+                response.status().as_str(),
+                response.text().await?
+            )));
+        }
+
+        match response
+            .json::<ApiResponse<ImageGenerationResponse>>()
+            .await?
+        {
+            ApiResponse::Ok(response) => response.try_into(),
+            ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
+        }
+    }
+}
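One sharp edge worth noting: the `TryFrom` above calls `.expect("Could not decode image.")`, so a malformed base64 payload panics even though the conversion returns a `Result`. A non-panicking alternative (a sketch of an option, not what this commit ships) maps the decode failure into the existing `ResponseError` variant:

    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;
    use rig::image_generation::ImageGenerationError;

    // Decode a provider's base64 image payload without panicking.
    fn decode_image(b64: &str) -> Result<Vec<u8>, ImageGenerationError> {
        BASE64_STANDARD
            .decode(b64)
            .map_err(|e| ImageGenerationError::ResponseError(format!("invalid base64 image: {e}")))
    }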
(File diff suppressed because it is too large.)
@@ -0,0 +1,515 @@
+use super::completion::CompletionModel;
+use super::embedding::{
+    EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002,
+};
+
+#[cfg(feature = "image")]
+use super::image_generation::ImageGenerationModel;
+use super::transcription::TranscriptionModel;
+use crate::agent::AgentBuilder;
+use crate::embeddings::EmbeddingsBuilder;
+use crate::extractor::ExtractorBuilder;
+
+use crate::Embed;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+// ================================================================
+// Main OpenAI Client
+// ================================================================
+const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new OpenAI client with the given API key.
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, OPENAI_API_BASE_URL)
+    }
+
+    /// Create a new OpenAI client with the given API key and base API URL.
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("OpenAI reqwest client should build"),
+        }
+    }
+
+    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+        self.http_client.post(url)
+    }
+
+    /// Create an embedding model with the given name.
+    /// Note: default embedding dimension of 0 will be used if model is not known.
+    /// If this is the case, it's better to use function `embedding_model_with_ndims`
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE);
+    /// ```
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        let ndims = match model {
+            TEXT_EMBEDDING_3_LARGE => 3072,
+            TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536,
+            _ => 0,
+        };
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let embedding_model = openai.embedding_model_with_ndims("model-unknown-to-rig", 3072);
+    /// ```
+    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let embeddings = openai.embeddings(openai::TEXT_EMBEDDING_3_LARGE)
+    ///     .simple_document("doc0", "Hello, world!")
+    ///     .simple_document("doc1", "Goodbye, world!")
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
+    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
+        EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+
+    /// Create a completion model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let gpt4 = openai.completion_model(openai::GPT_4);
+    /// ```
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let agent = openai.agent(openai::GPT_4)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+
+    /// Create a transcription model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let whisper = openai.transcription_model(openai::WHISPER_1);
+    /// ```
+    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
+        TranscriptionModel::new(self.clone(), model)
+    }
+
+    /// Create an image generation model with the given name.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::openai::{Client, self};
+    ///
+    /// // Initialize the OpenAI client
+    /// let openai = Client::new("your-open-ai-api-key");
+    ///
+    /// let dalle = openai.image_generation_model(openai::DALL_E_3);
+    /// ```
+    #[cfg(feature = "image")]
+    pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
+        ImageGenerationModel::new(self.clone(), model)
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ApiErrorResponse {
+    pub(crate) message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::message::ImageDetail;
+    use crate::providers::openai::{
+        AssistantContent, Function, ImageUrl, Message, ToolCall, ToolType, UserContent,
+    };
+    use crate::{message, OneOrMany};
+    use serde_path_to_error::deserialize;
+
+    #[test]
+    fn test_deserialize_message() {
+        let assistant_message_json = r#"
+        {
+            "role": "assistant",
+            "content": "\n\nHello there, how may I assist you today?"
+        }
+        "#;
+
+        let assistant_message_json2 = r#"
+        {
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "\n\nHello there, how may I assist you today?"
+                }
+            ],
+            "tool_calls": null
+        }
+        "#;
+
+        let assistant_message_json3 = r#"
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
+                    "type": "function",
+                    "function": {
+                        "name": "subtract",
+                        "arguments": "{\"x\": 2, \"y\": 5}"
+                    }
+                }
+            ],
+            "content": null,
+            "refusal": null
+        }
+        "#;
+
+        let user_message_json = r#"
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's in this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                    }
+                },
+                {
+                    "type": "audio",
+                    "input_audio": {
+                        "data": "...",
+                        "format": "mp3"
+                    }
+                }
+            ]
+        }
+        "#;
+
+        let assistant_message: Message = {
+            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
+            deserialize(jd).unwrap_or_else(|err| {
+                panic!(
+                    "Deserialization error at {} ({}:{}): {}",
+                    err.path(),
+                    err.inner().line(),
+                    err.inner().column(),
+                    err
+                );
+            })
+        };
+
+        let assistant_message2: Message = {
+            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
+            deserialize(jd).unwrap_or_else(|err| {
+                panic!(
+                    "Deserialization error at {} ({}:{}): {}",
+                    err.path(),
+                    err.inner().line(),
+                    err.inner().column(),
+                    err
+                );
+            })
+        };
+
+        let assistant_message3: Message = {
+            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
+                &mut serde_json::Deserializer::from_str(assistant_message_json3);
+            deserialize(jd).unwrap_or_else(|err| {
+                panic!(
+                    "Deserialization error at {} ({}:{}): {}",
+                    err.path(),
+                    err.inner().line(),
+                    err.inner().column(),
+                    err
+                );
+            })
+        };
+
+        let user_message: Message = {
+            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
+            deserialize(jd).unwrap_or_else(|err| {
+                panic!(
+                    "Deserialization error at {} ({}:{}): {}",
+                    err.path(),
+                    err.inner().line(),
+                    err.inner().column(),
+                    err
+                );
+            })
+        };
+
+        match assistant_message {
+            Message::Assistant { content, .. } => {
+                assert_eq!(
+                    content[0],
+                    AssistantContent::Text {
+                        text: "\n\nHello there, how may I assist you today?".to_string()
+                    }
+                );
+            }
+            _ => panic!("Expected assistant message"),
+        }
+
+        match assistant_message2 {
+            Message::Assistant {
+                content,
+                tool_calls,
+                ..
+            } => {
+                assert_eq!(
+                    content[0],
+                    AssistantContent::Text {
+                        text: "\n\nHello there, how may I assist you today?".to_string()
+                    }
+                );
+
+                assert_eq!(tool_calls, vec![]);
+            }
+            _ => panic!("Expected assistant message"),
+        }
+
+        match assistant_message3 {
+            Message::Assistant {
+                content,
+                tool_calls,
+                refusal,
+                ..
+            } => {
+                assert!(content.is_empty());
+                assert!(refusal.is_none());
+                assert_eq!(
+                    tool_calls[0],
+                    ToolCall {
+                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
+                        r#type: ToolType::Function,
+                        function: Function {
+                            name: "subtract".to_string(),
+                            arguments: serde_json::json!({"x": 2, "y": 5}),
+                        },
+                    }
+                );
+            }
+            _ => panic!("Expected assistant message"),
+        }
+
+        match user_message {
+            Message::User { content, .. } => {
+                let (first, second) = {
+                    let mut iter = content.into_iter();
+                    (iter.next().unwrap(), iter.next().unwrap())
+                };
+                assert_eq!(
+                    first,
+                    UserContent::Text {
+                        text: "What's in this image?".to_string()
+                    }
+                );
+                assert_eq!(second, UserContent::Image { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(), detail: ImageDetail::default() } });
+            }
+            _ => panic!("Expected user message"),
+        }
+    }
+
+    #[test]
+    fn test_message_to_message_conversion() {
+        let user_message = message::Message::User {
+            content: OneOrMany::one(message::UserContent::text("Hello")),
+        };
+
+        let assistant_message = message::Message::Assistant {
+            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
+        };
+
+        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
+        let converted_assistant_message: Vec<Message> =
+            assistant_message.clone().try_into().unwrap();
+
+        match converted_user_message[0].clone() {
+            Message::User { content, .. } => {
+                assert_eq!(
+                    content.first(),
+                    UserContent::Text {
+                        text: "Hello".to_string()
+                    }
+                );
+            }
+            _ => panic!("Expected user message"),
+        }
+
+        match converted_assistant_message[0].clone() {
+            Message::Assistant { content, .. } => {
+                assert_eq!(
+                    content[0].clone(),
+                    AssistantContent::Text {
+                        text: "Hi there!".to_string()
+                    }
+                );
+            }
+            _ => panic!("Expected assistant message"),
+        }
+
+        let original_user_message: message::Message =
+            converted_user_message[0].clone().try_into().unwrap();
+        let original_assistant_message: message::Message =
+            converted_assistant_message[0].clone().try_into().unwrap();
+
+        assert_eq!(original_user_message, user_message);
+        assert_eq!(original_assistant_message, assistant_message);
+    }
+
+    #[test]
+    fn test_message_from_message_conversion() {
+        let user_message = Message::User {
+            content: OneOrMany::one(UserContent::Text {
+                text: "Hello".to_string(),
+            }),
+            name: None,
+        };
+
+        let assistant_message = Message::Assistant {
+            content: vec![AssistantContent::Text {
+                text: "Hi there!".to_string(),
+            }],
+            refusal: None,
+            audio: None,
+            name: None,
+            tool_calls: vec![],
+        };
+
+        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
+        let converted_assistant_message: message::Message =
+            assistant_message.clone().try_into().unwrap();
+
+        match converted_user_message.clone() {
+            message::Message::User { content } => {
+                assert_eq!(content.first(), message::UserContent::text("Hello"));
+            }
+            _ => panic!("Expected user message"),
+        }
+
+        match converted_assistant_message.clone() {
+            message::Message::Assistant { content } => {
+                assert_eq!(
+                    content.first(),
+                    message::AssistantContent::text("Hi there!")
+                );
+            }
+            _ => panic!("Expected assistant message"),
+        }
+
+        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
+        let original_assistant_message: Vec<Message> =
+            converted_assistant_message.try_into().unwrap();
+
+        assert_eq!(original_user_message[0], user_message);
+        assert_eq!(original_assistant_message[0], assistant_message);
+    }
+}
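Because `from_url` stays public after the split, the same client can point at any OpenAI-compatible endpoint. A sketch with a placeholder gateway URL (not a real service):

    use rig::providers::openai::client::Client;

    fn main() {
        // Hypothetical OpenAI-compatible gateway; only the base URL changes.
        let client = Client::from_url("your-api-key", "https://example.com/v1");
        let _gpt4 = client.completion_model("gpt-4");
    }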
@ -0,0 +1,704 @@
|
|||
// ================================================================
|
||||
// OpenAI Completion API
|
||||
// ================================================================
|
||||
|
||||
use super::{ApiErrorResponse, ApiResponse, Client, Usage};
|
||||
use crate::completion::{CompletionError, CompletionRequest};
|
||||
use crate::message::{AudioMediaType, ImageDetail};
|
||||
use crate::one_or_many::string_or_one_or_many;
|
||||
use crate::{completion, json_utils, message, OneOrMany};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::convert::Infallible;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// `o3-mini` completion model
|
||||
pub const O3_MINI: &str = "o3-mini";
|
||||
/// `o3-mini-2025-01-31` completion model
|
||||
pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31";
|
||||
/// 'o1' completion model
|
||||
pub const O1: &str = "o1";
|
||||
/// `o1-2024-12-17` completion model
|
||||
pub const O1_2024_12_17: &str = "o1-2024-12-17";
|
||||
/// `o1-preview` completion model
|
||||
pub const O1_PREVIEW: &str = "o1-preview";
|
||||
/// `o1-preview-2024-09-12` completion model
|
||||
pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12";
|
||||
/// `o1-mini completion model
|
||||
pub const O1_MINI: &str = "o1-mini";
|
||||
/// `o1-mini-2024-09-12` completion model
|
||||
pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12";
|
||||
/// `gpt-4.5-preview` completion model
|
||||
pub const GPT_4_5_PREVIEW: &str = "gpt-4.5-preview";
|
||||
/// `gpt-4.5-preview-2025-02-27` completion model
|
||||
pub const GPT_4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27";
|
||||
/// `gpt-4o` completion model
|
||||
pub const GPT_4O: &str = "gpt-4o";
|
||||
/// `gpt-4o-mini` completion model
|
||||
pub const GPT_4O_MINI: &str = "gpt-4o-mini";
|
||||
/// `gpt-4o-2024-05-13` completion model
|
||||
pub const GPT_4O_2024_05_13: &str = "gpt-4o-2024-05-13";
|
||||
/// `gpt-4-turbo` completion model
|
||||
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
|
||||
/// `gpt-4-turbo-2024-04-09` completion model
|
||||
pub const GPT_4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09";
|
||||
/// `gpt-4-turbo-preview` completion model
|
||||
pub const GPT_4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview";
|
||||
/// `gpt-4-0125-preview` completion model
|
||||
pub const GPT_4_0125_PREVIEW: &str = "gpt-4-0125-preview";
|
||||
/// `gpt-4-1106-preview` completion model
|
||||
pub const GPT_4_1106_PREVIEW: &str = "gpt-4-1106-preview";
|
||||
/// `gpt-4-vision-preview` completion model
|
||||
pub const GPT_4_VISION_PREVIEW: &str = "gpt-4-vision-preview";
|
||||
/// `gpt-4-1106-vision-preview` completion model
|
||||
pub const GPT_4_1106_VISION_PREVIEW: &str = "gpt-4-1106-vision-preview";
|
||||
/// `gpt-4` completion model
|
||||
pub const GPT_4: &str = "gpt-4";
|
||||
/// `gpt-4-0613` completion model
|
||||
pub const GPT_4_0613: &str = "gpt-4-0613";
|
||||
/// `gpt-4-32k` completion model
|
||||
pub const GPT_4_32K: &str = "gpt-4-32k";
|
||||
/// `gpt-4-32k-0613` completion model
|
||||
pub const GPT_4_32K_0613: &str = "gpt-4-32k-0613";
|
||||
/// `gpt-3.5-turbo` completion model
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
/// `gpt-3.5-turbo-0125` completion model
|
||||
pub const GPT_35_TURBO_0125: &str = "gpt-3.5-turbo-0125";
|
||||
/// `gpt-3.5-turbo-1106` completion model
|
||||
pub const GPT_35_TURBO_1106: &str = "gpt-3.5-turbo-1106";
|
||||
/// `gpt-3.5-turbo-instruct` completion model
|
||||
pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
impl From<ApiErrorResponse> for CompletionError {
|
||||
fn from(err: ApiErrorResponse) -> Self {
|
||||
CompletionError::ProviderError(err.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
|
||||
type Error = CompletionError;
|
||||
|
||||
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
|
||||
let choice = response.choices.first().ok_or_else(|| {
|
||||
CompletionError::ResponseError("Response contained no choices".to_owned())
|
||||
})?;
|
||||
|
||||
let content = match &choice.message {
|
||||
Message::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
} => {
|
||||
let mut content = content
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
let s = match c {
|
||||
AssistantContent::Text { text } => text,
|
||||
AssistantContent::Refusal { refusal } => refusal,
|
||||
};
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(completion::AssistantContent::text(s))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
content.extend(
|
||||
tool_calls
|
||||
.iter()
|
||||
.map(|call| {
|
||||
completion::AssistantContent::tool_call(
|
||||
&call.id,
|
||||
&call.function.name,
|
||||
call.function.arguments.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
Ok(content)
|
||||
}
|
||||
_ => Err(CompletionError::ResponseError(
|
||||
"Response did not contain a valid message or tool call".into(),
|
||||
)),
|
||||
}?;
|
||||
|
||||
let choice = OneOrMany::many(content).map_err(|_| {
|
||||
CompletionError::ResponseError(
|
||||
"Response contained no message or tool call (empty)".to_owned(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(completion::CompletionResponse {
|
||||
choice,
|
||||
raw_response: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
pub logprobs: Option<serde_json::Value>,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum Message {
|
||||
System {
|
||||
#[serde(deserialize_with = "string_or_one_or_many")]
|
||||
content: OneOrMany<SystemContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
User {
|
||||
#[serde(deserialize_with = "string_or_one_or_many")]
|
||||
content: OneOrMany<UserContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
Assistant {
|
||||
#[serde(default, deserialize_with = "json_utils::string_or_vec")]
|
||||
content: Vec<AssistantContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
refusal: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
audio: Option<AudioAssistant>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(
|
||||
default,
|
||||
deserialize_with = "json_utils::null_or_vec",
|
||||
skip_serializing_if = "Vec::is_empty"
|
||||
)]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
#[serde(rename = "tool")]
|
||||
ToolResult {
|
||||
tool_call_id: String,
|
||||
content: OneOrMany<ToolResultContent>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn system(content: &str) -> Self {
|
||||
Message::System {
|
||||
content: OneOrMany::one(content.to_owned().into()),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct AudioAssistant {
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct SystemContent {
|
||||
#[serde(default)]
|
||||
r#type: SystemContentType,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SystemContentType {
|
||||
#[default]
|
||||
Text,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum AssistantContent {
|
||||
Text { text: String },
|
||||
Refusal { refusal: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum UserContent {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
#[serde(rename = "image_url")]
|
||||
Image {
|
||||
image_url: ImageUrl,
|
||||
},
|
||||
Audio {
|
||||
input_audio: InputAudio,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(default)]
|
||||
pub detail: ImageDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct InputAudio {
|
||||
pub data: String,
|
||||
pub format: AudioMediaType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct ToolResultContent {
|
||||
#[serde(default)]
|
||||
r#type: ToolResultContentType,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ToolResultContentType {
|
||||
#[default]
|
||||
Text,
|
||||
}
|
||||
|
||||
impl FromStr for ToolResultContent {
|
||||
type Err = Infallible;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(s.to_owned().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ToolResultContent {
|
||||
fn from(s: String) -> Self {
|
||||
ToolResultContent {
|
||||
r#type: ToolResultContentType::default(),
|
||||
text: s,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub r#type: ToolType,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ToolType {
|
||||
#[default]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct ToolDefinition {
|
||||
pub r#type: String,
|
||||
pub function: completion::ToolDefinition,
|
||||
}
|
||||
|
||||
impl From<completion::ToolDefinition> for ToolDefinition {
    fn from(tool: completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => {
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // If a message carries both tool results and other user content, OpenAI
                // only handles the tool results. It's unlikely that a message has both.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                tool_call_id: id,
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text { text }) => {
                                        Ok(text.into())
                                    }
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            _ => unreachable!(),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    let other_content = OneOrMany::many(other_content).expect(
                        "There must be other content here if there were no tool result content",
                    );

                    Ok(vec![Message::User {
                        content: other_content.map(|content| match content {
                            message::UserContent::Text(message::Text { text }) => {
                                UserContent::Text { text }
                            }
                            message::UserContent::Image(message::Image {
                                data, detail, ..
                            }) => UserContent::Image {
                                image_url: ImageUrl {
                                    url: data,
                                    detail: detail.unwrap_or_default(),
                                },
                            },
                            message::UserContent::Document(message::Document { data, .. }) => {
                                UserContent::Text { text: data }
                            }
                            message::UserContent::Audio(message::Audio {
                                data,
                                media_type,
                                ..
                            }) => UserContent::Audio {
                                input_audio: InputAudio {
                                    data,
                                    format: match media_type {
                                        Some(media_type) => media_type,
                                        None => AudioMediaType::MP3,
                                    },
                                },
                            },
                            _ => unreachable!(),
                        }),
                        name: None,
                    }])
                }
            }
            message::Message::Assistant { content } => {
                let (text_content, tool_calls) = content.into_iter().fold(
                    (Vec::new(), Vec::new()),
                    |(mut texts, mut tools), content| {
                        match content {
                            message::AssistantContent::Text(text) => texts.push(text),
                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
                        }
                        (texts, tools)
                    },
                );

                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
                // so either `content` or `tool_calls` will have some content.
                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| content.text.into())
                        .collect::<Vec<_>>(),
                    refusal: None,
                    audio: None,
                    name: None,
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
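
// Sketch (not part of the original diff): how the conversion above behaves for
// a plain text user message. Uses only constructors already used elsewhere in
// this file (`message::UserContent::text`, `OneOrMany::one`).
#[cfg(test)]
mod message_conversion_sketch {
    use super::*;

    #[test]
    fn user_text_converts_to_a_single_user_message() {
        let user = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("hello")),
        };

        // No tool results are present, so the `else` branch runs and we get
        // back exactly one `Message::User`.
        let converted: Vec<Message> = user.try_into().expect("plain text should convert");
        assert_eq!(converted.len(), 1);
        assert!(matches!(converted[0], Message::User { .. }));
    }
}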

impl From<message::ToolCall> for ToolCall {
    fn from(tool_call: message::ToolCall) -> Self {
        Self {
            id: tool_call.id,
            r#type: ToolType::default(),
            function: Function {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl From<ToolCall> for message::ToolCall {
    fn from(tool_call: ToolCall) -> Self {
        Self {
            id: tool_call.id,
            function: message::ToolFunction {
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            },
        }
    }
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        Ok(match message {
            Message::User { content, .. } => message::Message::User {
                content: content.map(|content| content.into()),
            },
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                let mut content = content
                    .into_iter()
                    .map(|content| match content {
                        AssistantContent::Text { text } => message::AssistantContent::text(text),

                        // TODO: Currently, refusals are converted into text, but this
                        // should be investigated for generalization.
                        AssistantContent::Refusal { refusal } => {
                            message::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                content.extend(
                    tool_calls
                        .into_iter()
                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
                        .collect::<Result<Vec<_>, _>>()?,
                );

                message::Message::Assistant {
                    content: OneOrMany::many(content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "Neither `content` nor `tool_calls` was provided to the Message"
                                .to_owned(),
                        )
                    })?,
                }
            }

            Message::ToolResult {
                tool_call_id,
                content,
            } => message::Message::User {
                content: OneOrMany::one(message::UserContent::tool_result(
                    tool_call_id,
                    content.map(|content| message::ToolResultContent::text(content.text)),
                )),
            },

            // System messages should get stripped out when converting messages; this is
            // just a stopgap to avoid obnoxious error handling or panics.
            Message::System { content, .. } => message::Message::User {
                content: content.map(|content| message::UserContent::text(content.text)),
            },
        })
    }
}

impl From<UserContent> for message::UserContent {
    fn from(content: UserContent) -> Self {
        match content {
            UserContent::Text { text } => message::UserContent::text(text),
            UserContent::Image { image_url } => message::UserContent::image(
                image_url.url,
                Some(message::ContentFormat::default()),
                None,
                Some(image_url.detail),
            ),
            UserContent::Audio { input_audio } => message::UserContent::audio(
                input_audio.data,
                Some(message::ContentFormat::default()),
                Some(input_audio.format),
            ),
        }
    }
}

impl From<String> for UserContent {
    fn from(s: String) -> Self {
        UserContent::Text { text: s }
    }
}

impl FromStr for UserContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(UserContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for AssistantContent {
    fn from(s: String) -> Self {
        AssistantContent::Text { text: s }
    }
}

impl FromStr for AssistantContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(AssistantContent::Text {
            text: s.to_string(),
        })
    }
}

impl From<String> for SystemContent {
    fn from(s: String) -> Self {
        SystemContent {
            r#type: SystemContentType::default(),
            text: s,
        }
    }
}

impl FromStr for SystemContent {
    type Err = Infallible;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(SystemContent {
            r#type: SystemContentType::default(),
            text: s.to_string(),
        })
    }
}

#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
    pub model: String,
}

impl CompletionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Add preamble to chat history (if available)
        let mut full_history: Vec<Message> = match &completion_request.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Convert prompt to user message
        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;

        // Convert existing chat history
        let chat_history: Vec<Message> = completion_request
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        // Combine all messages into a single history
        full_history.extend(chat_history);
        full_history.extend(prompt);

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
            })
        };

        // Only include `temperature` when it is set, because some models
        // don't support it.
        let request = if let Some(temperature) = completion_request.temperature {
            json_utils::merge(
                request,
                json!({
                    "temperature": temperature,
                }),
            )
        } else {
            request
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
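
// Sketch (not part of the original diff): the shape of the body produced by
// `create_completion_request` when a preamble, one tool, and a temperature are
// all set. Message serialization is abbreviated and all values are
// illustrative only.
#[allow(dead_code)]
fn example_completion_body() -> Value {
    json!({
        "model": "gpt-4o",
        "messages": [
            { "role": "system", "content": "You are a helpful assistant." },
            { "role": "user", "content": "What is 2 + 2?" }
        ],
        "tools": [{
            "type": "function",
            "function": { "name": "add", "description": "Add two numbers", "parameters": {} }
        }],
        "tool_choice": "auto",
        // merged in afterwards via `json_utils::merge`:
        "temperature": 0.7
    })
}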

impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            let t = response.text().await?;
            tracing::debug!(target: "rig", "OpenAI completion response: {}", t);

            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }
}

@ -0,0 +1,118 @@
use super::{ApiErrorResponse, ApiResponse, Client, Usage};
use crate::embeddings;
use crate::embeddings::EmbeddingError;
use serde::Deserialize;
use serde_json::json;

// ================================================================
// OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

#[derive(Clone)]
pub struct EmbeddingModel {
    client: Client,
    pub model: String,
    ndims: usize,
}

impl embeddings::EmbeddingModel for EmbeddingModel {
    const MAX_DOCUMENTS: usize = 1024;

    fn ndims(&self) -> usize {
        self.ndims
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let response = self
            .client
            .post("/embeddings")
            .json(&json!({
                "model": self.model,
                "input": documents,
            }))
            .send()
            .await?;

        if response.status().is_success() {
            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            Err(EmbeddingError::ProviderError(response.text().await?))
        }
    }
}

impl EmbeddingModel {
    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.to_string(),
            ndims,
        }
    }
}
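
// Sketch (not part of the original diff): embedding two documents with the
// model above. Assumes 1536 dimensions for `text-embedding-3-small`
// (OpenAI's published default for that model).
#[allow(dead_code)]
async fn embed_example(client: Client) -> Result<(), EmbeddingError> {
    use crate::embeddings::EmbeddingModel as _;

    let model = EmbeddingModel::new(client, TEXT_EMBEDDING_3_SMALL, 1536);
    let embeddings = model
        .embed_texts(vec!["hello".to_string(), "world".to_string()])
        .await?;

    // `embed_texts` returns one `Embedding` per input document, in order.
    assert_eq!(embeddings.len(), 2);
    Ok(())
}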

@ -0,0 +1,98 @@
use crate::image_generation;
use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
use crate::providers::openai::{ApiResponse, Client};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use serde::Deserialize;
use serde_json::json;

// ================================================================
// OpenAI Image Generation API
// ================================================================
pub const DALL_E_2: &str = "dall-e-2";
pub const DALL_E_3: &str = "dall-e-3";

#[derive(Debug, Deserialize)]
pub struct ImageGenerationData {
    pub b64_json: String,
}

#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
    pub created: i32,
    pub data: Vec<ImageGenerationData>,
}

impl TryFrom<ImageGenerationResponse>
    for image_generation::ImageGenerationResponse<ImageGenerationResponse>
{
    type Error = ImageGenerationError;

    fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
        // The response is untrusted input, so surface missing or malformed
        // data as errors instead of panicking. Assumes `ImageGenerationError`
        // exposes a `ResponseError` variant like the other provider error
        // enums in this crate.
        let Some(data) = value.data.first() else {
            return Err(ImageGenerationError::ResponseError(
                "Response contained no image data".into(),
            ));
        };

        let bytes = BASE64_STANDARD
            .decode(&data.b64_json)
            .map_err(|e| ImageGenerationError::ResponseError(format!("Failed to decode b64: {e}")))?;

        Ok(image_generation::ImageGenerationResponse {
            image: bytes,
            response: value,
        })
    }
}

#[derive(Clone)]
pub struct ImageGenerationModel {
    client: Client,
    /// Name of the model (e.g.: dall-e-2)
    pub model: String,
}

impl ImageGenerationModel {
    pub(crate) fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl image_generation::ImageGenerationModel for ImageGenerationModel {
    type Response = ImageGenerationResponse;

    async fn image_generation(
        &self,
        generation_request: ImageGenerationRequest,
    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
    {
        let request = json!({
            "model": self.model,
            "prompt": generation_request.prompt,
            "size": format!("{}x{}", generation_request.width, generation_request.height),
            "response_format": "b64_json"
        });

        let response = self
            .client
            .post("/images/generations")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(ImageGenerationError::ProviderError(format!(
                "{}: {}",
                response.status(),
                response.text().await?
            )));
        }

        let t = response.text().await?;

        match serde_json::from_str::<ApiResponse<ImageGenerationResponse>>(&t)? {
            ApiResponse::Ok(response) => response.try_into(),
            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
        }
    }
}
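
// Sketch (not part of the original diff): the JSON body the implementation
// above posts to `/images/generations` for a 1024x1024 request (values
// illustrative).
#[allow(dead_code)]
fn example_image_body() -> serde_json::Value {
    json!({
        "model": DALL_E_3,
        "prompt": "a watercolor fox",
        // built as format!("{}x{}", width, height)
        "size": "1024x1024",
        // so the image comes back base64-encoded instead of as a URL
        "response_format": "b64_json"
    })
}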

@ -0,0 +1,27 @@
//! OpenAI API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::openai;
//!
//! let client = openai::Client::new("YOUR_API_KEY");
//!
//! let gpt4o = client.completion_model(openai::GPT_4O);
//! ```
pub mod client;
pub mod completion;
pub mod embedding;

#[cfg(feature = "image")]
pub mod image_generation;
pub mod streaming;
pub mod transcription;

pub use client::*;
pub use completion::*;
pub use embedding::*;

#[cfg(feature = "image")]
pub use image_generation::*;
pub use streaming::*;
pub use transcription::*;
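
// Note: the image generation items above compile only when rig is built with
// the `image` feature. A usage sketch, assuming the client exposes an
// `image_generation_model` constructor mirroring `completion_model`:
//
// #[cfg(feature = "image")]
// let dalle = client.image_generation_model(DALL_E_3);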

@ -0,0 +1,184 @@
use super::completion::CompletionModel;
use crate::completion::{CompletionError, CompletionRequest};
use crate::json_utils;
use crate::json_utils::merge;
use crate::streaming;
use crate::streaming::{StreamingCompletionModel, StreamingResult};
use async_stream::stream;
use futures::StreamExt;
use reqwest::RequestBuilder;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;

// ================================================================
// OpenAI Completion Streaming API
// ================================================================
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingToolCall {
    pub index: usize,
    pub function: StreamingFunction,
}

#[derive(Deserialize)]
struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
}

#[derive(Deserialize)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize)]
struct StreamingCompletionResponse {
    choices: Vec<StreamingChoice>,
}

impl StreamingCompletionModel for CompletionModel {
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingResult, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;
        request = merge(request, json!({"stream": true}));

        let builder = self.client.post("/chat/completions").json(&request);
        send_compatible_streaming_request(builder).await
    }
}

pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
) -> Result<StreamingResult, CompletionError> {
    let response = request_builder.send().await?;

    if !response.status().is_success() {
        return Err(CompletionError::ProviderError(format!(
            "{}: {}",
            response.status(),
            response.text().await?
        )));
    }

    // Handle OpenAI-compatible SSE chunks
    Ok(Box::pin(stream! {
        let mut stream = response.bytes_stream();

        let mut partial_data = None;
        let mut calls: HashMap<usize, (String, String)> = HashMap::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(CompletionError::from(e));
                    break;
                }
            };

            let text = match String::from_utf8(chunk.to_vec()) {
                Ok(t) => t,
                Err(e) => {
                    yield Err(CompletionError::ResponseError(e.to_string()));
                    break;
                }
            };

            for line in text.lines() {
                let mut line = line.to_string();

                // If a previous chunk ended mid-payload, prepend the leftover
                // part to the current line
                if partial_data.is_some() {
                    line = format!("{}{}", partial_data.unwrap(), line);
                    partial_data = None;
                }
                // Otherwise expect a full `data: ...` line
                else {
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };

                    // Partial payload, split somewhere in the middle
                    if !line.ends_with("}") {
                        partial_data = Some(data.to_string());
                    } else {
                        line = data.to_string();
                    }
                }

                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);

                let Ok(data) = data else {
                    continue;
                };

                let choice = data.choices.first().expect("Should have at least one choice");

                let delta = &choice.delta;

                if !delta.tool_calls.is_empty() {
                    for tool_call in &delta.tool_calls {
                        let function = tool_call.function.clone();

                        // Start of a tool call:
                        // name: Some(String), arguments: empty
                        if function.name.is_some() && function.arguments.is_empty() {
                            calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
                        }
                        // Continuation of a tool call:
                        // name: None, arguments: non-empty fragment
                        else if function.name.is_none() && !function.arguments.is_empty() {
                            let Some((name, arguments)) = calls.get(&tool_call.index) else {
                                continue;
                            };

                            let new_arguments = &tool_call.function.arguments;
                            let arguments = format!("{}{}", arguments, new_arguments);

                            calls.insert(tool_call.index, (name.clone(), arguments));
                        }
                        // Entire tool call delivered in a single delta
                        else {
                            let name = function.name.unwrap();
                            let arguments = function.arguments;
                            let Ok(arguments) = serde_json::from_str(&arguments) else {
                                continue;
                            };

                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
                        }
                    }
                }

                if let Some(content) = &choice.delta.content {
                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
                }
            }
        }

        // Flush any tool calls that were accumulated across multiple deltas
        for (_, (name, arguments)) in calls {
            let Ok(arguments) = serde_json::from_str(&arguments) else {
                continue;
            };

            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
        }
    }))
}
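
// Sketch (not part of the original diff): draining a `StreamingResult`
// produced above and printing deltas as they arrive. Assumes the two
// `StreamingChoice` variants yielded by this function.
#[allow(dead_code)]
async fn print_stream_example(mut result: StreamingResult) {
    while let Some(item) = result.next().await {
        match item {
            Ok(streaming::StreamingChoice::Message(text)) => print!("{text}"),
            Ok(streaming::StreamingChoice::ToolCall(name, _id, args)) => {
                println!("\ntool call: {name}({args})");
            }
            Err(e) => {
                eprintln!("stream error: {e}");
                break;
            }
        }
    }
}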

@ -0,0 +1,105 @@
use crate::providers::openai::{ApiResponse, Client};
use crate::transcription;
use crate::transcription::TranscriptionError;
use reqwest::multipart::Part;
use serde::Deserialize;

// ================================================================
// OpenAI Transcription API
// ================================================================
pub const WHISPER_1: &str = "whisper-1";

#[derive(Debug, Deserialize)]
pub struct TranscriptionResponse {
    pub text: String,
}

impl TryFrom<TranscriptionResponse>
    for transcription::TranscriptionResponse<TranscriptionResponse>
{
    type Error = TranscriptionError;

    fn try_from(value: TranscriptionResponse) -> Result<Self, Self::Error> {
        Ok(transcription::TranscriptionResponse {
            text: value.text.clone(),
            response: value,
        })
    }
}

#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper-1)
    pub model: String,
}

impl TranscriptionModel {
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl transcription::TranscriptionModel for TranscriptionModel {
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .text("language", request.language)
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt);
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to OpenAI Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let response = self
            .client
            .post("audio/transcriptions")
            .multipart(body)
            .send()
            .await?;

        if response.status().is_success() {
            match response
                .json::<ApiResponse<TranscriptionResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(response.text().await?))
        }
    }
}
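
// Sketch (not part of the original diff): the multipart form the method above
// builds, with illustrative values. Optional fields ("prompt", "temperature",
// additional params) are appended only when present on the request.
#[allow(dead_code)]
fn example_transcription_form(audio_bytes: Vec<u8>) -> reqwest::multipart::Form {
    reqwest::multipart::Form::new()
        .text("model", WHISPER_1)
        .text("language", "en")
        .part("file", Part::bytes(audio_bytes).file_name("speech.mp3"))
        // optional:
        .text("prompt", "A podcast about Rust.")
        .text("temperature", "0.0")
}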