T5 Wasm (#918)
* init t5 wasm model
* split workers for each model
* clean up
* add some ui
* readme
* index
* typo
* remove cache param, clear_kv_cache
* add max_length as param
* add model tasks option to ui
* add method to load quantized gguf from buffer
* Add quantized wasm module
* add quantized models to UI, dynamic import wasms
* link to quantized
* fix copy
* fix ModelEncoder
* fix README.md
This commit is contained in: parent 8601537e31, commit 19e52e5007
@ -12,6 +12,7 @@ members = [
    "candle-wasm-examples/whisper",
    "candle-wasm-examples/yolo",
    "candle-wasm-examples/bert",
    "candle-wasm-examples/t5",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
@ -30,6 +30,21 @@ impl VarBuilder {
        })
    }

    /// Load all tensors from a gguf file already held in memory, e.g. bytes
    /// fetched over HTTP in a wasm context.
    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
        let mut cursor = std::io::Cursor::new(buffer);
        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
        let mut data = std::collections::HashMap::new();
        for tensor_name in content.tensor_infos.keys() {
            let tensor = content.tensor(&mut cursor, tensor_name)?;
            data.insert(tensor_name.to_string(), Arc::new(tensor));
        }
        Ok(Self {
            data: Arc::new(data),
            path: Vec::new(),
            device: Device::Cpu,
        })
    }

    fn pp<S: ToString>(&self, s: S) -> Self {
        let mut path = self.path.clone();
        path.push(s.to_string());
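This buffer-based loader is what lets the quantized bindings accept a `.gguf` file fetched over HTTP. As a minimal JS sketch of the resulting flow (assuming the `m-quantized.js` bundle produced by `build-lib.sh` and the `ModelConditionalGeneration` binding defined later in this diff; the file names are placeholders matching the hosted checkpoints in `utils.js`):

```js
import init, { ModelConditionalGeneration } from "./build/m-quantized.js";

await init(); // instantiate the wasm module

// Fetch raw bytes; the constructor hands the weights to
// VarBuilder::from_gguf_buffer on the Rust side.
const fetchBytes = async (url) =>
  new Uint8Array(await (await fetch(url)).arrayBuffer());

const model = new ModelConditionalGeneration(
  await fetchBytes("model.gguf"),
  await fetchBytes("tokenizer.json"),
  await fetchBytes("config.json")
);
```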
@ -0,0 +1,33 @@
[package]
name = "candle-wasm-example-t5"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.2.2", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.2.2" }
candle-transformers = { path = "../../candle-transformers", version = "0.2.2" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
safetensors = { workspace = true }

# Wasm specific crates.
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2", features = ["js"] }
gloo = "0.8"
js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
serde-wasm-bindgen = "0.6.0"
@ -0,0 +1,32 @@
## Running T5 with Candle and WASM

Here, we provide two examples of how to run T5 using a Candle-compiled WASM binary and runtime.

### Vanilla JS and WebWorkers

To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:

```bash
sh build-lib.sh
```

This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:

```js
import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m.js";
```

For the quantized version, we need to import the quantized module:

```js
import init, { ModelConditionalGeneration, ModelEncoder } from "./build/m-quantized.js";
```
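Both modules expose the same API. As a rough sketch of how the exports are driven once loaded (the `weights`, `tokenizer` and `config` arguments are `Uint8Array` buffers you fetch yourself; see the workers in this example for the full fetch-and-cache logic):

```js
await init(); // compile and instantiate the wasm module first

// Text-to-text generation.
const gen = new ModelConditionalGeneration(weights, tokenizer, config);
const { generation } = gen.decode({
  prompt: "translate English to German: Hello",
  temperature: 0.0,
  seed: 299792458,
  top_p: 1,
  repeat_penalty: 1.1,
  repeat_last_n: 64,
});

// Sentence embeddings.
const encoder = new ModelEncoder(weights, tokenizer, config);
const { embeddings } = encoder.decode({
  sentences: ["Hello world"],
  normalize_embeddings: true,
});
```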

The full example can be found under `./index.html`. All needed assets are fetched from the web, so there's no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:

```bash
python -m http.server
```

Then open `http://localhost:8000/index.html` in your browser.
@ -0,0 +1,93 @@
// Load the Candle T5 wasm module; the plain or quantized build is
// dynamically imported depending on the selected model.
let init, ModelConditionalGeneration;

async function fetchArrayBuffer(url) {
  const cacheName = "t5-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
class ConditionalGeneration {
  // Cache one model instance per modelID.
  static instance = {};

  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    // Pick the plain or quantized wasm bundle based on the model id.
    if (modelID.includes("quantized")) {
      ({ default: init, ModelConditionalGeneration } = await import(
        "./build/m-quantized.js"
      ));
    } else {
      ({ default: init, ModelConditionalGeneration } = await import(
        "./build/m.js"
      ));
    }
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new ModelConditionalGeneration(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}

self.addEventListener("message", async (event) => {
  const { weightsURL, tokenizerURL, configURL, modelID, prompt, params } =
    event.data;
  let {
    temperature = 0.0,
    seed = 299792458,
    repeat_penalty = 1.1,
    repeat_last_n = 64,
    top_p = 1,
    max_length = null,
  } = { ...params };
  try {
    self.postMessage({
      status: "ready",
      message: "Starting T5 Conditional Generation",
    });
    const model = await ConditionalGeneration.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({
      status: "decoding",
      message: "Decoding Prompt",
    });
    const output = model.decode({
      prompt,
      temperature,
      seed,
      top_p,
      repeat_penalty,
      repeat_last_n,
      max_length,
    });
    self.postMessage({
      status: "complete",
      message: "complete",
      output: output,
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
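For reference, a sketch of the calling side of this worker's message protocol, condensed from `generateText` in `utils.js` further down in this diff (the URL and prompt variables are assumed to be resolved by the caller, e.g. via `getModelInfo`):

```js
const worker = new Worker("./T5ModelConditionalGeneration.js", {
  type: "module",
});
worker.postMessage({
  weightsURL, // URLs resolved via getModelInfo in utils.js
  tokenizerURL,
  configURL,
  modelID,
  prompt,
  params: { temperature: 0, seed: 299792458, top_p: 1, repeat_penalty: 1.1, repeat_last_n: 64 },
});
worker.addEventListener("message", ({ data }) => {
  if ("error" in data) throw new Error(data.error);
  // data.status cycles through "ready" | "loading" | "decoding" | "complete".
  if (data.status === "complete") console.log(data.output.generation);
});
```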
@ -0,0 +1,83 @@
// Load the Candle T5 encoder wasm module; the plain or quantized build is
// dynamically imported depending on the selected model.
let init, ModelEncoder;

async function fetchArrayBuffer(url) {
  const cacheName = "t5-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
class Encoder {
  // Cache one model instance per modelID.
  static instance = {};

  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    if (modelID.includes("quantized")) {
      ({ default: init, ModelEncoder } = await import(
        "./build/m-quantized.js"
      ));
    } else {
      ({ default: init, ModelEncoder } = await import("./build/m.js"));
    }
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new ModelEncoder(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}

self.addEventListener("message", async (event) => {
  const {
    weightsURL,
    tokenizerURL,
    configURL,
    modelID,
    sentences,
    normalize_embeddings,
  } = event.data;
  try {
    self.postMessage({ status: "ready", message: "Starting T5 Encoder" });
    const model = await Encoder.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({
      status: "encoding",
      message: "Encoding Sentences",
    });
    const output = model.decode({
      sentences: sentences,
      // Default to true while still respecting an explicit `false`.
      normalize_embeddings: normalize_embeddings ?? true,
    });
    self.postMessage({
      status: "complete",
      message: "complete",
      output: output,
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
@ -0,0 +1,3 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m-quantized.wasm --out-dir build --target web
@ -0,0 +1,276 @@
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Candle T5</title>
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");

      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
    </style>
    <style type="text/tailwindcss">
      .link {
        @apply underline hover:text-blue-500 hover:no-underline;
      }
    </style>
    <script src="https://cdn.tailwindcss.com"></script>
    <script type="module">
      import {
        getModelInfo,
        MODELS,
        extractEmbeddings,
        generateText,
      } from "./utils.js";

      const t5ModelEncoderWorker = new Worker("./T5ModelEncoderWorker.js", {
        type: "module",
      });
      const t5ModelConditionalGeneration = new Worker(
        "./T5ModelConditionalGeneration.js",
        { type: "module" }
      );

      const formEl = document.querySelector("#form");
      const modelEl = document.querySelector("#model");
      const promptEl = document.querySelector("#prompt");
      const temperatureEl = document.querySelector("#temperature");
      const toppEL = document.querySelector("#top-p");
      const repeatPenaltyEl = document.querySelector("#repeat_penalty");
      const seedEl = document.querySelector("#seed");
      const outputEl = document.querySelector("#output-generation");
      const tasksEl = document.querySelector("#tasks");
      let selectedTaskID = "";

      document.addEventListener("DOMContentLoaded", () => {
        for (const [id, model] of Object.entries(MODELS)) {
          const option = document.createElement("option");
          option.value = id;
          option.innerText = `${id} (${model.size})`;
          modelEl.appendChild(option);
        }
        populateTasks(modelEl.value);
        modelEl.addEventListener("change", (e) => {
          populateTasks(e.target.value);
        });
        tasksEl.addEventListener("change", (e) => {
          const task = e.target.value;
          const modelID = modelEl.value;
          promptEl.value = MODELS[modelID].tasks[task].prefix;
          selectedTaskID = task;
        });
      });
      function populateTasks(modelID) {
        const tasks = MODELS[modelID].tasks;
        tasksEl.innerHTML = "";
        for (const [task, params] of Object.entries(tasks)) {
          const div = document.createElement("div");
          div.innerHTML = `
            <input
              type="radio"
              name="task"
              id="${task}"
              class="font-light cursor-pointer"
              value="${task}" />
            <label for="${task}" class="cursor-pointer">
              ${params.prefix}
            </label>
          `;
          tasksEl.appendChild(div);
        }
        selectedTaskID = Object.keys(tasks)[0];
        tasksEl.querySelector(`#${selectedTaskID}`).checked = true;
      }
      formEl.addEventListener("submit", (e) => {
        e.preventDefault();

        const promptText = promptEl.value;
        const modelID = modelEl.value;
        const { modelURL, configURL, tokenizerURL, maxLength } = getModelInfo(
          modelID,
          selectedTaskID
        );
        const params = {
          temperature: Number(temperatureEl.value),
          top_p: Number(toppEL.value),
          repeat_penalty: Number(repeatPenaltyEl.value),
          seed: BigInt(seedEl.value),
          max_length: maxLength,
        };
        generateText(
          t5ModelConditionalGeneration,
          modelURL,
          tokenizerURL,
          configURL,
          modelID,
          promptText,
          params,
          (status) => {
            if (status.status === "loading") {
              outputEl.innerText = "Loading model...";
            }
            if (status.status === "decoding") {
              outputEl.innerText = "Generating...";
            }
          }
        ).then(({ output }) => {
          outputEl.innerText = output.generation;
        });
      });
    </script>
  </head>

  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-8 relative">
      <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
      <div>
        <h1 class="text-5xl font-bold">Candle T5 Transformer</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          This demo showcases Text-To-Text Transfer Transformer (<a
            href="https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html"
            target="_blank"
            class="link"
            >T5</a
          >) models right in your browser, thanks to the
          <a
            href="https://github.com/huggingface/candle/"
            target="_blank"
            class="link">
            Candle
          </a>
          ML framework and Rust/WASM. You can choose from a range of available
          models, including
          <a
            href="https://huggingface.co/t5-small"
            target="_blank"
            class="link">
            t5-small</a
          >,
          <a href="https://huggingface.co/t5-base" target="_blank" class="link"
            >t5-base</a
          >,
          <a
            href="https://huggingface.co/google/flan-t5-small"
            target="_blank"
            class="link"
            >flan-t5-small</a
          >
          and several
          <a
            href="https://huggingface.co/lmz/candle-quantized-t5/tree/main"
            target="_blank"
            class="link">
            quantized gguf t5 checkpoints</a
          >.
        </p>
      </div>

      <div>
        <label for="model" class="font-medium">Model Options: </label>
        <select
          id="model"
          class="border-2 border-gray-500 rounded-md font-light"></select>
      </div>

      <div>
        <h3 class="font-medium">Task Prefix:</h3>
        <form id="tasks" class="flex flex-col gap-1 my-2"></form>
      </div>
      <form
        id="form"
        class="flex text-normal px-1 py-1 border border-gray-700 rounded-md items-center">
        <input type="submit" hidden />
        <input
          type="text"
          id="prompt"
          class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
          placeholder="Add prompt here, e.g. 'translate English to German: Today I'm going to eat Ice Cream'"
          value="translate English to German: Today I'm going to eat Ice Cream" />
        <button
          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-2 w-16 rounded disabled:bg-gray-300 disabled:cursor-not-allowed">
          Run
        </button>
      </form>
      <div class="grid grid-cols-3 max-w-md items-center gap-3">
        <label class="text-sm font-medium" for="temperature">Temperature</label>
        <input
          type="range"
          id="temperature"
          name="temperature"
          min="0"
          max="2"
          step="0.01"
          value="0.00"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
          0.00</output
        >
        <label class="text-sm font-medium" for="top-p">Top-p</label>
        <input
          type="range"
          id="top-p"
          name="top-p"
          min="0"
          max="1"
          step="0.01"
          value="1.00"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md">
          1.00</output
        >

        <label class="text-sm font-medium" for="repeat_penalty"
          >Repeat Penalty</label
        >

        <input
          type="range"
          id="repeat_penalty"
          name="repeat_penalty"
          min="-2"
          max="2"
          step="0.01"
          value="1.10"
          oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)" />
        <output
          class="text-xs w-[50px] text-center font-light px-1 py-1 border border-gray-700 rounded-md"
          >1.10</output
        >
        <label class="text-sm font-medium" for="seed">Seed</label>
        <input
          type="number"
          id="seed"
          name="seed"
          value="299792458"
          class="font-light border border-gray-700 text-right rounded-md p-2" />
        <button
          id="run"
          onclick="document.querySelector('#seed').value = BigInt(Math.floor(Math.random() * (2**64 - 1)))"
          class="bg-gray-700 hover:bg-gray-800 text-white font-normal py-1 w-[50px] rounded disabled:bg-gray-300 disabled:cursor-not-allowed text-sm">
          Rand
        </button>
      </div>
      <div>
        <h3 class="font-medium">Generation:</h3>
        <div
          class="min-h-[250px] bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2 text-lg">
          <p id="output-generation" class="grid-rows-2">No output yet</p>
        </div>
      </div>
    </main>
  </body>
</html>
@ -0,0 +1,205 @@
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
pub use candle_transformers::models::quantized_t5::{
    Config, T5EncoderModel, T5ForConditionalGeneration, VarBuilder,
};

use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct ModelEncoder {
    model: T5EncoderModel,
    tokenizer: Tokenizer,
}

#[wasm_bindgen]
pub struct ModelConditionalGeneration {
    model: T5ForConditionalGeneration,
    tokenizer: Tokenizer,
    config: Config,
}

#[wasm_bindgen]
impl ModelConditionalGeneration {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        config: Vec<u8>,
    ) -> Result<ModelConditionalGeneration, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let vb = VarBuilder::from_gguf_buffer(&weights)?;
        let mut config: Config = serde_json::from_slice(&config)?;
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let model = T5ForConditionalGeneration::load(vb, &config)?;
        config.use_cache = false;
        Ok(Self {
            model,
            tokenizer,
            config,
        })
    }
    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
        let input: ConditionalGenerationParams =
            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
        let device = &Device::Cpu;
        self.model.clear_kv_cache();
        // The decoder is seeded with the pad token, which doubles as the
        // decoder start token for T5.
        let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
        let prompt = input.prompt;
        let repeat_penalty = input.repeat_penalty;
        let repeat_last_n = input.repeat_last_n;
        let seed = input.seed;
        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
        let temperature = if input.temperature <= 0. {
            None
        } else {
            Some(input.temperature)
        };
        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
            None
        } else {
            Some(input.top_p)
        };
        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
        let tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|m| JsError::new(&m.to_string()))?
            .get_ids()
            .to_vec();

        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        let encoder_output = self.model.encode(&input_token_ids)?;
        let mut decoded = String::new();
        // Autoregressive decoding: the first step feeds the whole prefix,
        // later steps only feed the last sampled token and rely on the kv cache.
        for index in 0.. {
            if output_token_ids.len() > max_length {
                break;
            }
            let decoder_token_ids = if index == 0 {
                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
            } else {
                let last_token = *output_token_ids.last().unwrap();
                Tensor::new(&[last_token], device)?.unsqueeze(0)?
            };
            let logits = self
                .model
                .decode(&decoder_token_ids, &encoder_output)?
                .squeeze(0)?;
            let logits = if repeat_penalty == 1. {
                logits
            } else {
                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    repeat_penalty,
                    &output_token_ids[start_at..],
                )?
            };

            let next_token_id = logits_processor.sample(&logits)?;
            if next_token_id as usize == self.config.eos_token_id {
                break;
            }
            output_token_ids.push(next_token_id);
            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
                decoded += &text;
            }
        }
        Ok(serde_wasm_bindgen::to_value(
            &ConditionalGenerationOutput {
                generation: decoded,
            },
        )?)
    }
}

#[wasm_bindgen]
impl ModelEncoder {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        config: Vec<u8>,
    ) -> Result<ModelEncoder, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let vb = VarBuilder::from_gguf_buffer(&weights)?;
        let mut config: Config = serde_json::from_slice(&config)?;
        config.use_cache = false;
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let model = T5EncoderModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }

    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
        let device = &Device::Cpu;
        let input: DecoderParams =
            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;

        self.model.clear_kv_cache();
        let sentences = input.sentences;
        let normalize_embeddings = input.normalize_embeddings;
        let n_sentences = sentences.len();
        let mut all_embeddings = Vec::with_capacity(n_sentences);
        for sentence in sentences {
            let tokens = self
                .tokenizer
                .encode(sentence, true)
                .map_err(|m| JsError::new(&m.to_string()))?
                .get_ids()
                .to_vec();
            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
            let embeddings = self.model.forward(&token_ids)?;
            console_log!("generated embeddings {:?}", embeddings.shape());
            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
            // Optional L2 normalization of the pooled embedding.
            let embeddings = if normalize_embeddings {
                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
            } else {
                embeddings
            };
            console_log!("{:?}", embeddings.shape());
            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
        }

        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
            embeddings: all_embeddings,
        })?)
    }
}

#[derive(serde::Serialize, serde::Deserialize)]
struct ConditionalGenerationOutput {
    generation: String,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct DecoderOutput {
    embeddings: Vec<Vec<f32>>,
}

#[derive(serde::Serialize, serde::Deserialize)]
pub struct DecoderParams {
    sentences: Vec<String>,
    normalize_embeddings: bool,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConditionalGenerationParams {
    prompt: String,
    temperature: f64,
    seed: u64,
    top_p: f64,
    repeat_penalty: f32,
    repeat_last_n: usize,
    max_length: Option<usize>,
}
fn main() {
    console_error_panic_hook::set_once();
}
@ -0,0 +1,206 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
pub use candle_transformers::models::t5::{Config, T5EncoderModel, T5ForConditionalGeneration};
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
pub struct ModelEncoder {
    model: T5EncoderModel,
    tokenizer: Tokenizer,
}

#[wasm_bindgen]
pub struct ModelConditionalGeneration {
    model: T5ForConditionalGeneration,
    tokenizer: Tokenizer,
    config: Config,
}

#[wasm_bindgen]
impl ModelConditionalGeneration {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        config: Vec<u8>,
    ) -> Result<ModelConditionalGeneration, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let device = &Device::Cpu;
        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        let mut config: Config = serde_json::from_slice(&config)?;
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let model = T5ForConditionalGeneration::load(vb, &config)?;
        config.use_cache = false;
        Ok(Self {
            model,
            tokenizer,
            config,
        })
    }
    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
        let input: ConditionalGenerationParams =
            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
        let device = &Device::Cpu;
        self.model.clear_kv_cache();
        // The decoder is seeded with the pad token, which doubles as the
        // decoder start token for T5.
        let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
        let prompt = input.prompt;
        let repeat_penalty = input.repeat_penalty;
        let repeat_last_n = input.repeat_last_n;
        let seed = input.seed;
        let max_length = usize::clamp(input.max_length.unwrap_or(512), 0, 512);
        let temperature = if input.temperature <= 0. {
            None
        } else {
            Some(input.temperature)
        };
        let top_p = if input.top_p <= 0. || input.top_p >= 1. {
            None
        } else {
            Some(input.top_p)
        };
        let mut logits_processor = LogitsProcessor::new(seed, temperature, top_p);
        let tokens = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|m| JsError::new(&m.to_string()))?
            .get_ids()
            .to_vec();

        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        let encoder_output = self.model.encode(&input_token_ids)?;
        let mut decoded = String::new();
        // Autoregressive decoding: the first step feeds the whole prefix,
        // later steps only feed the last sampled token and rely on the kv cache.
        for index in 0.. {
            if output_token_ids.len() > max_length {
                break;
            }
            let decoder_token_ids = if index == 0 {
                Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
            } else {
                let last_token = *output_token_ids.last().unwrap();
                Tensor::new(&[last_token], device)?.unsqueeze(0)?
            };
            let logits = self
                .model
                .decode(&decoder_token_ids, &encoder_output)?
                .squeeze(0)?;
            let logits = if repeat_penalty == 1. {
                logits
            } else {
                let start_at = output_token_ids.len().saturating_sub(repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    repeat_penalty,
                    &output_token_ids[start_at..],
                )?
            };

            let next_token_id = logits_processor.sample(&logits)?;
            if next_token_id as usize == self.config.eos_token_id {
                break;
            }
            output_token_ids.push(next_token_id);
            if let Some(text) = self.tokenizer.id_to_token(next_token_id) {
                let text = text.replace('▁', " ").replace("<0x0A>", "\n");
                decoded += &text;
            }
        }
        Ok(serde_wasm_bindgen::to_value(
            &ConditionalGenerationOutput {
                generation: decoded,
            },
        )?)
    }
}

#[wasm_bindgen]
impl ModelEncoder {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        config: Vec<u8>,
    ) -> Result<ModelEncoder, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let device = &Device::Cpu;
        let weights = safetensors::tensor::SafeTensors::deserialize(&weights)?;
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, device);
        let mut config: Config = serde_json::from_slice(&config)?;
        config.use_cache = false;
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let model = T5EncoderModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }

    pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
        let device = &Device::Cpu;
        let input: DecoderParams =
            serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;

        self.model.clear_kv_cache();
        let sentences = input.sentences;
        let normalize_embeddings = input.normalize_embeddings;
        let n_sentences = sentences.len();
        let mut all_embeddings = Vec::with_capacity(n_sentences);
        for sentence in sentences {
            let tokens = self
                .tokenizer
                .encode(sentence, true)
                .map_err(|m| JsError::new(&m.to_string()))?
                .get_ids()
                .to_vec();
            let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
            let embeddings = self.model.forward(&token_ids)?;
            console_log!("generated embeddings {:?}", embeddings.shape());
            // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
            let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
            let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
            // Optional L2 normalization of the pooled embedding.
            let embeddings = if normalize_embeddings {
                embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
            } else {
                embeddings
            };
            console_log!("{:?}", embeddings.shape());
            all_embeddings.push(embeddings.squeeze(0)?.to_vec1::<f32>()?);
        }

        Ok(serde_wasm_bindgen::to_value(&DecoderOutput {
            embeddings: all_embeddings,
        })?)
    }
}

#[derive(serde::Serialize, serde::Deserialize)]
struct ConditionalGenerationOutput {
    generation: String,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct DecoderOutput {
    embeddings: Vec<Vec<f32>>,
}

#[derive(serde::Serialize, serde::Deserialize)]
pub struct DecoderParams {
    sentences: Vec<String>,
    normalize_embeddings: bool,
}
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConditionalGenerationParams {
    prompt: String,
    temperature: f64,
    seed: u64,
    top_p: f64,
    repeat_penalty: f32,
    repeat_last_n: usize,
    max_length: Option<usize>,
}
fn main() {
    console_error_panic_hook::set_once();
}
@ -0,0 +1,16 @@
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
extern "C" {
    // Use `js_namespace` here to bind `console.log(..)` instead of just
    // `log(..)`
    #[wasm_bindgen(js_namespace = console)]
    pub fn log(s: &str);
}

#[macro_export]
macro_rules! console_log {
    // Note that this is using the `log` function imported above during
    // `bare_bones`
    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}
@ -0,0 +1,168 @@
export async function extractEmbeddings(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  sentences,
  updateStatus,
  normalize_embeddings = true
) {
  return new Promise((resolve, reject) => {
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      sentences,
      normalize_embeddings,
    });
    function messageHandler(event) {
      if ("error" in event.data) {
        worker.removeEventListener("message", messageHandler);
        reject(new Error(event.data.error));
      }
      if (event.data.status === "complete") {
        worker.removeEventListener("message", messageHandler);
        resolve(event.data);
      }
      if (updateStatus) updateStatus(event.data);
    }
    worker.addEventListener("message", messageHandler);
  });
}

export async function generateText(
  worker,
  weightsURL,
  tokenizerURL,
  configURL,
  modelID,
  prompt,
  params,
  updateStatus
) {
  return new Promise((resolve, reject) => {
    worker.postMessage({
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      prompt,
      params,
    });
    function messageHandler(event) {
      if ("error" in event.data) {
        worker.removeEventListener("message", messageHandler);
        reject(new Error(event.data.error));
      }
      if (event.data.status === "complete") {
        worker.removeEventListener("message", messageHandler);
        resolve(event.data);
      }
      if (updateStatus) updateStatus(event.data);
    }
    worker.addEventListener("message", messageHandler);
  });
}
export const MODELS = {
  t5_small_quantized: {
    size: "102 MB",
    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
    model: "model.gguf",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: {
      translation_en_to_de: {
        prefix: "translate English to German: ",
        max_length: 300,
      },
      translation_en_to_fr: {
        prefix: "translate English to French: ",
        max_length: 300,
      },
      translation_en_to_ro: {
        prefix: "translate English to Romanian: ",
        max_length: 300,
      },
      summarization: { prefix: "summarize: ", max_length: 200 },
    },
  },
  t5_small: {
    size: "242 MB",
    base_url: "https://huggingface.co/t5-small/resolve/main/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: {
      translation_en_to_de: {
        prefix: "translate English to German: ",
        max_length: 300,
      },
      translation_en_to_fr: {
        prefix: "translate English to French: ",
        max_length: 300,
      },
      translation_en_to_ro: {
        prefix: "translate English to Romanian: ",
        max_length: 300,
      },
      summarization: { prefix: "summarize: ", max_length: 200 },
    },
  },
  flan_t5_small: {
    size: "308 MB",
    base_url:
      "https://huggingface.co/google/flan-t5-small/resolve/refs%2Fpr%2F14/",
    model: "model.safetensors",
    tokenizer: "tokenizer.json",
    config: "config.json",
    tasks: {
      translation_en_to_de: {
        prefix: "translate English to German: ",
        max_length: 300,
      },
      translation_en_to_fr: {
        prefix: "translate English to French: ",
        max_length: 300,
      },
      translation_en_to_ro: {
        prefix: "translate English to Romanian: ",
        max_length: 300,
      },
      summarization: { prefix: "summarize: ", max_length: 200 },
    },
  },

  flan_t5_base_quantized: {
    size: "360 MB",
    base_url: "https://huggingface.co/lmz/candle-quantized-t5/resolve/main/",
    model: "model-flan-t5-base.gguf",
    tokenizer: "tokenizer.json",
    config: "config-flan-t5-base.json",
    tasks: {
      translation_en_to_de: {
        prefix: "translate English to German: ",
        max_length: 300,
      },
      translation_en_to_fr: {
        prefix: "translate English to French: ",
        max_length: 300,
      },
      translation_en_to_ro: {
        prefix: "translate English to Romanian: ",
        max_length: 300,
      },
      summarization: { prefix: "summarize: ", max_length: 200 },
    },
  },
};
export function getModelInfo(id, taskID) {
  const model = MODELS[id];
  return {
    modelURL: model.base_url + model.model,
    configURL: model.base_url + model.config,
    tokenizerURL: model.base_url + model.tokenizer,
    maxLength: model.tasks[taskID].max_length,
  };
}
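For reference, a quick sketch of how these helpers fit together (hypothetical usage; `worker` is the conditional-generation worker created in `index.html`):

```js
const modelID = "t5_small_quantized";
const { modelURL, tokenizerURL, configURL, maxLength } = getModelInfo(
  modelID,
  "summarization"
);
const { output } = await generateText(
  worker, // a Worker running T5ModelConditionalGeneration.js
  modelURL,
  tokenizerURL,
  configURL,
  modelID,
  "summarize: Candle is a minimalist ML framework for Rust.",
  { temperature: 0, seed: 299792458, top_p: 1, repeat_penalty: 1.1, repeat_last_n: 64, max_length: maxLength },
  (status) => console.log(status.status)
);
console.log(output.generation);
```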