Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt
* lint
* separated python code for converting tokenizers into its own file and added a requirements.txt for dependencies, updated instructions in readme and included python version
* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
parent b4daa03e59
commit d6db305829
README.md

@@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
 mountain. I cannot stay far from you any longer.</s>
 ```
 
+### Changing model and language pairs
+
+```bash
+$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
+
+你好,你好吗?
+```
 
 ## Generating the tokenizer.json files
 
-You can use the following script to generate the `tokenizer.json` config files
-from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
-packages to be install and use the `convert_slow_tokenizer.py` script from this
-directory.
-
-```python
-from convert_slow_tokenizer import MarianConverter
-from transformers import AutoTokenizer
-
-tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
-fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
-fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
-fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
-fast_tokenizer.save(f"tokenizer-marian-base-en.json")
-```
+The tokenizer for each `marian-mt` model was trained independently,
+meaning each new model needs unique tokenizer encoders and decoders.
+You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
+the `tokenizer.json` config files from the hf-hub repos.
+The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
+to be installed, and has only been tested for `python 3.12.7`.
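As a worked example of the new README instructions, here is a minimal sketch of generating the encoder and decoder tokenizer files for the en-zh pair; it assumes the MarianConverter from python/convert_slow_tokenizer.py (added later in this commit) is importable, and reuses the repo id and file names that main.rs expects.

```python
# Hypothetical usage sketch, not part of the commit: build the two tokenizer.json
# files that the example looks for when run with --language-pair en-zh.
from transformers import AutoTokenizer

from convert_slow_tokenizer import MarianConverter  # script added in this commit

# Load the slow (sentencepiece-backed) tokenizer for the en-zh model.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh", use_fast=False)

# index=0 converts the source-language spm model (en), index=1 the target-language one (zh).
MarianConverter(tokenizer, index=0).converted().save("tokenizer-marian-base-en-zh-en.json")
MarianConverter(tokenizer, index=1).converted().save("tokenizer-marian-base-en-zh-zh.json")
```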
File diff suppressed because it is too large.
main.rs

@@ -20,6 +20,22 @@ enum Which {
     Big,
 }
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
+enum LanguagePair {
+    #[value(name = "fr-en")]
+    FrEn,
+    #[value(name = "en-zh")]
+    EnZh,
+    #[value(name = "en-hi")]
+    EnHi,
+    #[value(name = "en-es")]
+    EnEs,
+    #[value(name = "en-fr")]
+    EnFr,
+    #[value(name = "en-ru")]
+    EnRu,
+}
+
 // TODO: Maybe add support for the conditional prompt.
 #[derive(Parser)]
 struct Args {
@@ -36,6 +52,10 @@ struct Args {
     #[arg(long, default_value = "big")]
     which: Which,
 
+    // Choose which language pair to use
+    #[arg(long, default_value = "fr-en")]
+    language_pair: LanguagePair,
+
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
@@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
     use hf_hub::api::sync::Api;
     let args = Args::parse();
 
-    let config = match args.which {
-        Which::Base => marian::Config::opus_mt_fr_en(),
-        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
+    let config = match (args.which, args.language_pair) {
+        (Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
+        (Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
+        (Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
+        (Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
+        (Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
+        (Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
+        (Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
+        (Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
+    };
+    let tokenizer_default_repo = match args.language_pair {
+        LanguagePair::FrEn => "lmz/candle-marian",
+        LanguagePair::EnZh
+        | LanguagePair::EnHi
+        | LanguagePair::EnEs
+        | LanguagePair::EnFr
+        | LanguagePair::EnRu => "KeighBee/candle-marian",
     };
     let tokenizer = {
         let tokenizer = match args.tokenizer {
             Some(tokenizer) => std::path::PathBuf::from(tokenizer),
             None => {
-                let name = match args.which {
-                    Which::Base => "tokenizer-marian-base-fr.json",
-                    Which::Big => "tokenizer-marian-fr.json",
+                let filename = match (args.which, args.language_pair) {
+                    (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
+                    (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
+                    (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
+                    (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
+                    (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
+                    (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
+                    (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
+                    (Which::Big, lp) => {
+                        anyhow::bail!("big is not supported for language pair {lp:?}")
+                    }
                 };
                 Api::new()?
-                    .model("lmz/candle-marian".to_string())
-                    .get(name)?
+                    .model(tokenizer_default_repo.to_string())
+                    .get(filename)?
             }
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
         let tokenizer = match args.tokenizer_dec {
             Some(tokenizer) => std::path::PathBuf::from(tokenizer),
             None => {
-                let name = match args.which {
-                    Which::Base => "tokenizer-marian-base-en.json",
-                    Which::Big => "tokenizer-marian-en.json",
+                let filename = match (args.which, args.language_pair) {
+                    (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
+                    (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
+                    (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
+                    (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
+                    (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
+                    (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
+                    (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
+                    (Which::Big, lp) => {
+                        anyhow::bail!("big is not supported for language pair {lp:?}")
+                    }
                 };
                 Api::new()?
-                    .model("lmz/candle-marian".to_string())
-                    .get(name)?
+                    .model(tokenizer_default_repo.to_string())
+                    .get(filename)?
             }
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
     let vb = {
         let model = match args.model {
             Some(model) => std::path::PathBuf::from(model),
-            None => match args.which {
-                Which::Base => Api::new()?
-                    .repo(hf_hub::Repo::with_revision(
+            None => {
+                let api = Api::new()?;
+                let api = match (args.which, args.language_pair) {
+                    (Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
                         "Helsinki-NLP/opus-mt-fr-en".to_string(),
                         hf_hub::RepoType::Model,
                         "refs/pr/4".to_string(),
-                    ))
-                    .get("model.safetensors")?,
-                Which::Big => Api::new()?
-                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
-                    .get("model.safetensors")?,
-            },
+                    )),
+                    (Which::Big, LanguagePair::FrEn) => {
+                        api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
+                    }
+                    (Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-en-zh".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/13".to_string(),
+                    )),
+                    (Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-en-hi".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/3".to_string(),
+                    )),
+                    (Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-en-es".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/4".to_string(),
+                    )),
+                    (Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-en-fr".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/9".to_string(),
+                    )),
+                    (Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-en-ru".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/7".to_string(),
+                    )),
+                    (Which::Big, lp) => {
+                        anyhow::bail!("big is not supported for language pair {lp:?}")
+                    }
+                };
+                api.get("model.safetensors")?
+            }
         };
         unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
     };
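The base weights above are fetched from PR revisions (refs/pr/N) of the Helsinki-NLP repos, which is where the safetensors conversions live. Below is a minimal Python sketch of the same download, assuming the huggingface_hub package; the repo id and revision are the ones used for en-zh above.

```python
# Hypothetical sketch, not part of the commit: fetch the same model.safetensors the
# Rust example downloads, pointing at the PR revision that carries the safetensors file.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Helsinki-NLP/opus-mt-en-zh",
    filename="model.safetensors",
    revision="refs/pr/13",
)
print(path)
```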
python/convert_slow_tokenizer.py (new file)

@@ -0,0 +1,53 @@
from pathlib import Path
import warnings

from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf


class MarianConverter(SpmConverter):
    def __init__(self, *args, index: int = 0):
        requires_backends(self, "protobuf")

        super(SpmConverter, self).__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        print(self.original_tokenizer.spm_files)
        with open(self.original_tokenizer.spm_files[index], "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m
        print(self.original_tokenizer)
        #with open(self.original_tokenizer.vocab_path, "r") as f:
        dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
        with open(dir_path / "vocab.json", "r") as f:
            import json
            self._vocab = json.load(f)

        if self.proto.trainer_spec.byte_fallback:
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    def vocab(self, proto):
        vocab_size = max(self._vocab.values()) + 1
        vocab = [("<NIL>", -100) for _ in range(vocab_size)]
        for piece in proto.pieces:
            try:
                index = self._vocab[piece.piece]
            except Exception:
                print(f"Ignored missing piece {piece.piece}")
            vocab[index] = (piece.piece, piece.score)
        return vocab


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")
python/requirements.txt (new file)

@@ -0,0 +1,22 @@
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
filelock==3.18.0
fsspec==2025.3.2
huggingface-hub==0.30.1
idna==3.10
joblib==1.4.2
numpy==2.2.4
packaging==24.2
protobuf==6.30.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
sacremoses==0.1.1
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.1
tqdm==4.67.1
transformers==4.50.3
typing-extensions==4.13.0
urllib3==2.3.0
candle-transformers/src/models/marian.rs

@@ -81,6 +81,126 @@ impl Config {
             vocab_size: 59514,
         }
     }
+
+    pub fn opus_mt_en_zh() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 65000,
+            decoder_vocab_size: Some(65001),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 65000,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 65001,
+        }
+    }
+
+    pub fn opus_mt_en_hi() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 61949,
+            decoder_vocab_size: Some(61950),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 61949,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 61950,
+        }
+    }
+
+    pub fn opus_mt_en_es() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 65000,
+            decoder_vocab_size: Some(65001),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 65000,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 65001,
+        }
+    }
+
+    pub fn opus_mt_en_fr() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 59513,
+            decoder_vocab_size: Some(59514),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 59513,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 59514,
+        }
+    }
+
+    pub fn opus_mt_en_ru() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 62517,
+            decoder_vocab_size: Some(62518),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 62517,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 62518,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
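These constructors hard-code the values from each model's config.json on the hub. If you want to add yet another pair, a quick way to read the values off is a sketch like this (assuming the transformers package; opus-mt-en-ru is just one of the repos touched by this commit):

```python
# Hypothetical helper, not part of the commit: print the hub config fields that the
# Rust Config constructors above mirror, so a new language pair can be filled in.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
for field in (
    "activation_function",
    "d_model",
    "decoder_attention_heads",
    "decoder_ffn_dim",
    "decoder_layers",
    "decoder_start_token_id",
    "encoder_attention_heads",
    "encoder_ffn_dim",
    "encoder_layers",
    "eos_token_id",
    "forced_eos_token_id",
    "max_position_embeddings",
    "pad_token_id",
    "scale_embedding",
    "share_encoder_decoder_embeddings",
    "vocab_size",
):
    print(field, getattr(cfg, field, None))
```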