* init t5 wasm model

* split workers for each model

* clean up

* add some ui

* readme

* index

* typo

* remove cache param, clear_kv_cache

* add max_length as param

* add model tasks option to ui

* add method to load quantized gguf from buffer

* Add quantized wasm module

* add quantized models to UI, dynamic import wasms

* link to quantized

* fix copy

* fix ModelEncoder

* fix README.md
Radamés Ajna 2023-09-22 07:31:10 -07:00 committed by GitHub
parent 8601537e31
commit 19e52e5007
12 changed files with 1131 additions and 0 deletions


@@ -12,6 +12,7 @@ members = [
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/t5",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"


@@ -30,6 +30,21 @@ impl VarBuilder {
})
}
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: Device::Cpu,
})
}
fn pp<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());


@@ -0,0 +1,33 @@
[package]
name = "candle-wasm-example-t5"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.2.2" }
candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"


@@ -0,0 +1,32 @@
## Running T5 with Candle and WASM
Here, we provide two examples of how to run T5 using a Candle-compiled WASM binary and runtime.
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build`, and we can then import it inside our WebWorker like a normal JS module:
```js
import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m.js";
```
For the quantized version, we need to import the quantized module:
```js
import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m-quantized.js";
```
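Both builds expose the same API: after calling `init()`, the exported classes are constructed from the raw model, tokenizer, and config bytes. Below is a simplified sketch of what the workers in this example do (the helper names `fetchBytes` and `run` are just for the sketch; the weight URLs point at the quantized T5 files used by the demo and are only illustrative):
```js
import init, { ModelConditionalGeneration } from "./build/m-quantized.js";

async function fetchBytes(url) {
  // fetch a remote asset and return it as raw bytes
  const res = await fetch(url);
  return new Uint8Array(await res.arrayBuffer());
}

async function run(prompt) {
  await init(); // initialize the WASM module before constructing any model
  const base = "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/";
  const [weights, tokenizer, config] = await Promise.all([
    fetchBytes(base + "model.gguf"),
    fetchBytes(base + "tokenizer.json"),
    fetchBytes(base + "config.json"),
  ]);
  const model = new ModelConditionalGeneration(weights, tokenizer, config);
  // decode() runs synchronously and returns { generation: "..." }
  return model.decode({
    prompt,
    temperature: 0.0,
    seed: 299792458,
    top_p: 1.0,
    repeat_penalty: 1.1,
    repeat_last_n: 64,
    max_length: 300,
  });
}
```
In the actual workers this logic is wrapped in a class that caches the model per `modelID` and reports progress back to the page via `postMessage`.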
The full example can be found under `./index.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/index.html` in your browser.
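For reference, `./index.html` talks to the workers through small promise wrappers exported by `./utils.js` (shown later in this commit). A trimmed sketch of the main-thread side, assuming the code runs in a module script (for top-level `await`):
```js
import { getModelInfo, generateText } from "./utils.js";

const worker = new Worker("./T5ModelConditionalGeneration.js", { type: "module" });

// model and task IDs come from the MODELS table in ./utils.js
const modelID = "t5_small_quantized";
const { modelURL, tokenizerURL, configURL, maxLength } = getModelInfo(
  modelID,
  "translation_en_to_de"
);

const { output } = await generateText(
  worker,
  modelURL,
  tokenizerURL,
  configURL,
  modelID,
  "translate English to German: Today I'm going to eat Ice Cream",
  { temperature: 0, top_p: 1, repeat_penalty: 1.1, seed: 299792458, max_length: maxLength },
  (status) => console.log(status.message) // progress updates from the worker
);
console.log(output.generation);
```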


@@ -0,0 +1,93 @@
// Load the Candle T5 WASM module; the quantized or standard build is imported dynamically based on the model ID
let init, ModelConditionalGeneration;
async function fetchArrayBuffer(url) {
const cacheName = "t5-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class ConditionalGeneration {
static instance = {};
static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
if (modelID.includes("quantized")) {
({ default: init, ModelConditionalGeneration } = await import(
"./build/m-quantized.js"
));
} else {
({ default: init, ModelConditionalGeneration } = await import(
"./build/m.js"
));
}
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new ModelConditionalGeneration(
weightsArrayU8,
tokenizerArrayU8,
configArrayU8
);
} else {
self.postMessage({ status: "ready", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const { weightsURL, tokenizerURL, configURL, modelID, prompt, params } =
event.data;
let {
temperature = 0.0,
seed = 299792458,
repeat_penalty = 1.1,
repeat_last_n = 64,
top_p = 1,
max_length,
} = { ...params };
try {
self.postMessage({
status: "ready",
message: "Starting T5 Conditional Generation",
});
const model = await ConditionalGeneration.getInstance(
weightsURL,
tokenizerURL,
configURL,
modelID
);
self.postMessage({
status: "decoding",
message: "Decoding Prompt",
});
const output = model.decode({
prompt,
temperature,
seed,
top_p,
repeat_penalty,
repeat_last_n,
max_length,
});
self.postMessage({
status: "complete",
message: "complete",
output: output,
});
} catch (e) {
self.postMessage({ error: e });
}
});


@@ -0,0 +1,83 @@
// Load the Candle T5 encoder WASM module; the quantized or standard build is imported dynamically based on the model ID
let init, ModelEncoder;
async function fetchArrayBuffer(url) {
const cacheName = "t5-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class Encoder {
static instance = {};
static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
if (modelID.includes("quantized")) {
({ default: init, ModelEncoder } = await import(
"./build/m-quantized.js"
));
} else {
({ default: init, ModelEncoder } = await import("./build/m.js"));
}
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new ModelEncoder(
weightsArrayU8,
tokenizerArrayU8,
configArrayU8
);
} else {
self.postMessage({ status: "ready", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const {
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
normalize_embeddings,
} = event.data;
try {
self.postMessage({ status: "ready", message: "Starting T5 Encoder" });
const model = await Encoder.getInstance(
weightsURL,
tokenizerURL,
configURL,
modelID
);
self.postMessage({
status: "encoding",
message: "Encoding Sentences",
});
const output = model.decode({
sentences,
normalize_embeddings: normalize_embeddings ?? true,
});
self.postMessage({
status: "complete",
message: "complete",
output: output,
});
} catch (e) {
self.postMessage({ error: e });
}
});


@@ -0,0 +1,3 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m-quantized.wasm --out-dir build --target web


@@ -0,0 +1,276 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Candle T5</title>
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
</style>
<style type="text/tailwindcss">
.link {
@apply underline hover:text-blue-500 hover:no-underline;
}
</style>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module">
import {
getModelInfo,
MODELS,
extractEmbeddings,
generateText,
} from "./utils.js";
const t5ModelEncoderWorker = new Worker("./T5ModelEncoderWorker.js", {
type: "module",
});
const t5ModelConditionalGeneration = new Worker(
"./T5ModelConditionalGeneration.js",
{ type: "module" }
);
const formEl = document.querySelector("#form");
const modelEl = document.querySelector("#model");
const promptEl = document.querySelector("#prompt");
const temperatureEl = document.querySelector("#temperature");
const toppEL = document.querySelector("#top-p");
const repeatPenaltyEl = document.querySelector("#repeat_penalty");
const seedEl = document.querySelector("#seed");
const outputEl = document.querySelector("#output-generation");
const tasksEl = document.querySelector("#tasks");
let selectedTaskID = "";
document.addEventListener("DOMContentLoaded", () => {
for (const [id, model] of Object.entries(MODELS)) {
const option = document.createElement("option");
option.value = id;
option.innerText = `${id} (${model.size})`;
modelEl.appendChild(option);
}
populateTasks(modelEl.value);
modelEl.addEventListener("change", (e) => {
populateTasks(e.target.value);
});
tasksEl.addEventListener("change", (e) => {
const task = e.target.value;
const modelID = modelEl.value;
promptEl.value = MODELS[modelID].tasks[task].prefix;
selectedTaskID = task;
});
});
function populateTasks(modelID) {
const tasks = MODELS[modelID].tasks;
tasksEl.innerHTML = "";
for (const [task, params] of Object.entries(tasks)) {
const div = document.createElement("div");
div.innerHTML = `
<input
type="radio"
name="task"
id="${task}"
class="font-light cursor-pointer"
value="${task}" />
<label for="${task}" class="cursor-pointer">
${params.prefix}
</label>
`;
tasksEl.appendChild(div);
}
selectedTaskID = Object.keys(tasks)[0];
tasksEl.querySelector(`#${selectedTaskID}`).checked = true;
}
formEl.addEventListener("submit", (e) => {
e.preventDefault();
const promptText = promptEl.value;
const modelID = modelEl.value;
const { modelURL, configURL, tokenizerURL, maxLength } = getModelInfo(
modelID,
selectedTaskID
);
const params = {
temperature: Number(temperatureEl.value),
top_p: Number(toppEL.value),
repeat_penalty: Number(repeatPenaltyEl.value),
seed: BigInt(seedEl.value),
max_length: maxLength,
};
generateText(
t5ModelConditionalGeneration,
modelURL,
tokenizerURL,
configURL,
modelID,
promptText,
params,
(status) => {
if (status.status === "loading") {
outputEl.innerText = "Loading model...";
}
if (status.status === "decoding") {
outputEl.innerText = "Generating...";
}
}
).then(({ output }) => {
outputEl.innerText = output.generation;
});
});
</script>
</head>
<body class="container max-w-4xl mx-auto p-4">
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle T5 Transformer</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
This demo showcases the Text-To-Text Transfer Transformer (<a
href="https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html"
target="_blank"
class="link"
>T5</a
>) models right in your browser, thanks to the
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="link">
Candle
</a>
ML framework and Rust/WASM. You can choose from a range of available
models, including
<a
href="https://huggingface.co/t5-small"
target="_blank"
class="link">
t5-small</a
>,
<a href="https://huggingface.co/t5-base" target="_blank" class="link"
>t5-base</a
>,
<a
href="https://huggingface.co/google/flan-t5-small"
target="_blank"
class="link"
>flan-t5-small</a
>
and several
<a
href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
target="_blank"
class="link">
quantized t5 GGUF models</a
>.
</p>
</div>
<div>
<label for="model" class="font-medium">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light"></select>
</div>
<div>
<h3 class="font-medium">Task Prefix:</h3>
<form id="tasks" class="flex flex-col gap-1 my-2"></form>
</div>
<form
id="form"
class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
<input type="submit" hidden />
<input
type="text"
id="prompt"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
placeholder="Add prompt here, e.g. 'translate English to German: Today I'm going to eat Ice Cream'"
value="translate English to German: Today I'm going to eat Ice Cream" />
<button
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
Run
</button>
</form>
<div class="grid grid-cols-3 max-w-md items-center gap-3">
<label class="text-sm font-medium" for="temperature">Temperature</label>
<input
type="range"
id="temperature"
name="temperature"
min="0"
max="2"
step="0.01"
value="0.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
0.00</output
>
<label class="text-sm font-medium" for="top-p">Top-p</label>
<input
type="range"
id="top-p"
name="top-p"
min="0"
max="1"
step="0.01"
value="1.00"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
1.00</output
>
<label class="text-sm font-medium" for="repeat_penalty"
>Repeat Penalty</label
>
<input
type="range"
id="repeat_penalty"
name="repeat_penalty"
min="-2"
max="2"
step="0.01"
value="1.10"
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
<output
class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
>1.10</output
>
<label class="text-sm font-medium" for="seed">Seed</label>
<input
type="number"
id="seed"
name="seed"
value="299792458"
class="font-light border border-gray-700 text-right rounded-md p-2" />
<button
id="run"
onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * 2**64-1))"
class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
Rand
</button>
</div>
<div>
<h3 class="font-medium">Generation:</h3>
<div
class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2 text-lg">
<p id="output-generation" class="grid-rows-2">No output yet</p>
</div>
</div>
</main>
</body>
</html>


@@ -0,0 +1,205 @@
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
pub use candle_transformers::models::quantized_t5::{
Config, T5EncoderModel, T5ForConditionalGeneration, VarBuilder,
};
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct ModelEncoder {
model: T5EncoderModel,
tokenizer: Tokenizer,
}
#[wasm_bindgen]
pub struct ModelConditionalGeneration {
model: T5ForConditionalGeneration,
tokenizer: Tokenizer,
config: Config,
}
#[wasm_bindgen]
impl ModelConditionalGeneration {
#[wasm_bindgen(constructor)]
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
) -> Result<ModelConditionalGeneration, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let vb = VarBuilder::from_gguf_buffer(&weights)?;
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let model = T5ForConditionalGeneration::load(vb, &config)?;
config.use_cache = false;
Ok(Self {
model,
tokenizer,
config,
})
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: ConditionalGenerationParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let device = &Device::Cpu;
self.model.clear_kv_cache();
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
let prompt = input.prompt;
let repeat_penalty = input.repeat_penalty;
let repeat_last_n = input.repeat_last_n;
let seed = input.seed;
let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
let temperature = if input.temperature <= 0. {
None
} else {
Some(input.temperature)
};
let top_p = if input.top_p <= 0. || input.top_p >= 1. {
None
} else {
Some(input.top_p)
};
let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
let tokens = self
.tokenizer
.encode(prompt, true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let encoder_output = self.model.encode(&input_token_ids)?;
let mut decoded = String::new();
for index in 0.. {
if output_token_ids.len() > max_length {
break;
}
let decoder_token_ids = if index == 0 {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = self
.model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == self.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
decoded += &text;
}
}
Ok(serde_wasm_bindgen::to_value(
&ConditionalGenerationOutput {
generation: decoded,
},
)?)
}
}
#[wasm_bindgen]
impl ModelEncoder {
#[wasm_bindgen(constructor)]
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
) -> Result<ModelEncoder, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let vb = VarBuilder::from_gguf_buffer(&weights)?;
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let model = T5EncoderModel::load(vb, &config)?;
Ok(Self { model, tokenizer })
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let device = &Device::Cpu;
let input: DecoderParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
self.model.clear_kv_cache();
let sentences = input.sentences;
let normalize_embeddings = input.normalize_embeddings;
let n_sentences = sentences.len();
let mut all_embeddings = Vec::with_capacity(n_sentences);
for sentence in sentences {
let tokens = self
.tokenizer
.encode(sentence, true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.model.forward(&token_ids)?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if normalize_embeddings {
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
} else {
embeddings
};
console_log!("{:?}", embeddings.shape());
all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
}
Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
embeddings: all_embeddings,
})?)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct ConditionalGenerationOutput {
generation: String,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct DecoderOutput {
embeddings: Vec<Vec<f32>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct DecoderParams {
sentences: Vec<String>,
normalize_embeddings: bool,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConditionalGenerationParams {
prompt: String,
temperature: f64,
seed: u64,
top_p: f64,
repeat_penalty: f32,
repeat_last_n: usize,
max_length: Option<usize>,
}
fn main() {
console_error_panic_hook::set_once();
}


@@ -0,0 +1,206 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
pub use candle_transformers::models::t5::{Config, T5EncoderModel, T5ForConditionalGeneration};
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct ModelEncoder {
model: T5EncoderModel,
tokenizer: Tokenizer,
}
#[wasm_bindgen]
pub struct ModelConditionalGeneration {
model: T5ForConditionalGeneration,
tokenizer: Tokenizer,
config: Config,
}
#[wasm_bindgen]
impl ModelConditionalGeneration {
#[wasm_bindgen(constructor)]
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
) -> Result<ModelConditionalGeneration, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let model = T5ForConditionalGeneration::load(vb, &config)?;
config.use_cache = false;
Ok(Self {
model,
tokenizer,
config,
})
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: ConditionalGenerationParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let device = &Device::Cpu;
self.model.clear_kv_cache();
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
let prompt = input.prompt;
let repeat_penalty = input.repeat_penalty;
let repeat_last_n = input.repeat_last_n;
let seed = input.seed;
let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
let temperature = if input.temperature <= 0. {
None
} else {
Some(input.temperature)
};
let top_p = if input.top_p <= 0. || input.top_p >= 1. {
None
} else {
Some(input.top_p)
};
let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
let tokens = self
.tokenizer
.encode(prompt, true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let encoder_output = self.model.encode(&input_token_ids)?;
let mut decoded = String::new();
for index in 0.. {
if output_token_ids.len() > max_length {
break;
}
let decoder_token_ids = if index == 0 {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = self
.model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == self.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
decoded += &text;
}
}
Ok(serde_wasm_bindgen::to_value(
&ConditionalGenerationOutput {
generation: decoded,
},
)?)
}
}
#[wasm_bindgen]
impl ModelEncoder {
#[wasm_bindgen(constructor)]
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
) -> Result<ModelEncoder, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = &Device::Cpu;
let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let model = T5EncoderModel::load(vb, &config)?;
Ok(Self { model, tokenizer })
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let device = &Device::Cpu;
let input: DecoderParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
self.model.clear_kv_cache();
let sentences = input.sentences;
let normalize_embeddings = input.normalize_embeddings;
let n_sentences = sentences.len();
let mut all_embeddings = Vec::with_capacity(n_sentences);
for sentence in sentences {
let tokens = self
.tokenizer
.encode(sentence, true)
.map_err(|m| JsError::new(&m.to_string()))?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let embeddings = self.model.forward(&token_ids)?;
console_log!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if normalize_embeddings {
embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
} else {
embeddings
};
console_log!("{:?}", embeddings.shape());
all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
}
Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
embeddings: all_embeddings,
})?)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
struct ConditionalGenerationOutput {
generation: String,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct DecoderOutput {
embeddings: Vec<Vec<f32>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct DecoderParams {
sentences: Vec<String>,
normalize_embeddings: bool,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConditionalGenerationParams {
prompt: String,
temperature: f64,
seed: u64,
top_p: f64,
repeat_penalty: f32,
repeat_last_n: usize,
max_length: Option<usize>,
}
fn main() {
console_error_panic_hook::set_once();
}


@@ -0,0 +1,16 @@
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}


@@ -0,0 +1,168 @@
export async function extractEmbeddings(
worker,
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
updateStatus,
normalize_embeddings = true
) {
return new Promise((resolve, reject) => {
worker.postMessage({
weightsURL,
tokenizerURL,
configURL,
modelID,
sentences,
normalize_embeddings,
});
function messageHandler(event) {
if ("error" in event.data) {
worker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
worker.removeEventListener("message", messageHandler);
resolve(event.data);
}
if (updateStatus) updateStatus(event.data);
}
worker.addEventListener("message", messageHandler);
});
}
export async function generateText(
worker,
weightsURL,
tokenizerURL,
configURL,
modelID,
prompt,
params,
updateStatus
) {
return new Promise((resolve, reject) => {
worker.postMessage({
weightsURL,
tokenizerURL,
configURL,
modelID,
prompt,
params,
});
function messageHandler(event) {
if ("error" in event.data) {
worker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
worker.removeEventListener("message", messageHandler);
resolve(event.data);
}
if (updateStatus) updateStatus(event.data);
}
worker.addEventListener("message", messageHandler);
});
}
export const MODELS = {
t5_small_quantized: {
size: "102 MB",
base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
model: "model.gguf",
tokenizer: "tokenizer.json",
config: "config.json",
tasks: {
translation_en_to_de: {
prefix: "translate English to German: ",
max_length: 300,
},
translation_en_to_fr: {
prefix: "translate English to French: ",
max_length: 300,
},
translation_en_to_ro: {
prefix: "translate English to Romanian: ",
max_length: 300,
},
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
t5_small: {
size: "242 MB",
base_url: "https://huggingface.co/t5-small/resolve/main/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
tasks: {
translation_en_to_de: {
prefix: "translate English to German: ",
max_length: 300,
},
translation_en_to_fr: {
prefix: "translate English to French: ",
max_length: 300,
},
translation_en_to_ro: {
prefix: "translate English to Romanian: ",
max_length: 300,
},
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
flan_t5_small: {
size: "308 MB",
base_url:
"https://huggingface.co/google/flan-t5-small/resolve/refs%2Fpr%2F14/",
model: "model.safetensors",
tokenizer: "tokenizer.json",
config: "config.json",
tasks: {
translation_en_to_de: {
prefix: "translate English to German: ",
max_length: 300,
},
translation_en_to_fr: {
prefix: "translate English to French: ",
max_length: 300,
},
translation_en_to_ro: {
prefix: "translate English to Romanian: ",
max_length: 300,
},
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
flan_t5_base_quantized: {
size: "360 MB",
base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
model: "model-flan-t5-base.gguf",
tokenizer: "tokenizer.json",
config: "config-flan-t5-base.json",
tasks: {
translation_en_to_de: {
prefix: "translate English to German: ",
max_length: 300,
},
translation_en_to_fr: {
prefix: "translate English to French: ",
max_length: 300,
},
translation_en_to_ro: {
prefix: "translate English to Romanian: ",
max_length: 300,
},
summarization: { prefix: "summarize: ", max_length: 200 },
},
},
};
export function getModelInfo(id, taskID) {
const model = MODELS[id];
return {
modelURL: model.base_url + model.model,
configURL: model.base_url + model.config,
tokenizerURL: model.base_url + model.tokenizer,
maxLength: model.tasks[taskID].max_length,
};
}