parent 236b29ff15
commit 31a1075f4b
@@ -55,6 +55,10 @@ pub struct LSTMState {
}

impl LSTMState {
    pub fn new(h: Tensor, c: Tensor) -> Self {
        LSTMState { h, c }
    }

    /// The hidden state vector, which is also the output of the LSTM.
    pub fn h(&self) -> &Tensor {
        &self.h
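
Not part of the diff, just a minimal sketch of how the new constructor pairs with the existing accessor, assuming `h` and `c` are `[batch_size, hidden_size]` tensors (the helper name `zero_state` is made up for illustration):

use candle::{DType, Device, Result, Tensor};
use candle_nn::rnn::LSTMState;

fn zero_state(batch_size: usize, hidden_size: usize, dev: &Device) -> Result<LSTMState> {
    // Both the hidden state and the cell state start out as zeros.
    let h = Tensor::zeros((batch_size, hidden_size), DType::F32, dev)?;
    let c = Tensor::zeros((batch_size, hidden_size), DType::F32, dev)?;
    Ok(LSTMState::new(h, c))
}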
@@ -66,6 +66,18 @@ impl Attr for GraphProto {
    }
}

impl AttrOwned for Vec<String> {
    const TYPE: AttributeType = AttributeType::Strings;
    fn get(attr: &onnx::AttributeProto) -> Result<Self> {
        let mut ret = vec![];
        for bytes in attr.strings.iter() {
            let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
            ret.push(s);
        }
        Ok(ret)
    }
}

impl AttrOwned for Tensor {
    const TYPE: AttributeType = AttributeType::Tensor;
    fn get(attr: &onnx::AttributeProto) -> Result<Self> {
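
For context, a rough sketch (not part of the diff) of what the new `Vec<String>` decoding does, assuming it runs inside `eval.rs` where `AttrOwned` and the generated `onnx` types are in scope; the `Strings` payload is stored as raw bytes in `AttributeProto::strings`:

let attr = onnx::AttributeProto {
    name: "activations".to_string(),
    r#type: AttributeType::Strings.into(),
    strings: vec![b"Sigmoid".to_vec(), b"Tanh".to_vec(), b"Tanh".to_vec()],
    ..Default::default()
};
// Each raw byte string is decoded as UTF-8 and collected into a Vec<String>.
let activations = <Vec<String> as AttrOwned>::get(&attr)?;
assert_eq!(activations, vec!["Sigmoid", "Tanh", "Tanh"]);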
@@ -1310,6 +1322,233 @@ fn simple_eval_(
                    .broadcast_add(&c.broadcast_mul(&beta)?)?;
                values.insert(node.output[0].clone(), output);
            }
            "LSTM" => {
                let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
                if direction != "forward" {
                    bail!("LSTM currently only supports direction == \"forward\"");
                }
                let num_directions = if direction == "bidirectional" { 2 } else { 1 };
                let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
                let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
                if input_forget != 0 {
                    bail!("LSTM currently only supports input_forget == 0");
                }
                let activations_default = vec![
                    "Sigmoid".to_string(),
                    "Tanh".to_string(),
                    "Tanh".to_string(),
                ];
                let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
                    .unwrap_or(activations_default.clone());
                if activations != activations_default {
                    bail!("LSTM currently only supports default activations ({activations_default:?})");
                }
                // activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
                if get_attr_opt::<f32>(node, "clip")?.is_some() {
                    bail!("LSTM does not currently support clip attribute");
                }

                // The shape format of inputs X, initial_h and outputs Y, Y_h.
                // If 0, the following shapes are expected:
                // X.shape = [seq_length, batch_size, input_size],
                // Y.shape = [seq_length, num_directions, batch_size, hidden_size],
                // initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
                // If 1, the following shapes are expected:
                // X.shape = [batch_size, seq_length, input_size],
                // Y.shape = [batch_size, seq_length, num_directions, hidden_size],
                // initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
                let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
                if layout != 0 {
                    bail!("LSTM currently only supports layout == 0");
                }

                // The input sequences packed (and potentially padded) into one 3-D tensor
                // with the shape of `[seq_length, batch_size, input_size]`.
                let x = get(&node.input[0])?;
                // XXX: depends on layout
                let (seq_length, batch_size, input_size) = x.dims3()?;
                // The weight tensor for the gates.
                // Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
                // The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
                let w = get(&node.input[1])?;
                // The recurrence weight tensor.
                // Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
                // This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
                let r = get(&node.input[2])?;

                let get_opt = |i: usize| {
                    node.input
                        .get(i)
                        .filter(|s: &&String| !s.is_empty())
                        .map(|s| get(s))
                };

                // The bias tensor for input gate.
                // Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
                // This tensor has shape `[num_directions, 8*hidden_size]`.
                // Optional: If not specified - assumed to be 0.
                let b_default: Tensor;
                let b = match get_opt(3) {
                    Some(n) => n?,
                    None => {
                        b_default = Tensor::zeros(
                            (num_directions, 8 * hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &b_default
                    }
                };

                // Optional tensor specifying lengths of the sequences in a batch.
                // If not specified - assumed all sequences in the batch to have length `seq_length`.
                // It has shape `[batch_size]`.
                let seq_lens_default: Tensor;
                let seq_lens = match get_opt(4) {
                    Some(n) => n?,
                    None => {
                        seq_lens_default =
                            Tensor::full(seq_length as i64, (batch_size,), x.device())?;
                        &seq_lens_default
                    }
                };
                let seq_lens_is_default =
                    (seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
                if !seq_lens_is_default {
                    bail!("LSTM currently only supports default value of seq_lens");
                }

                // Optional initial value of the hidden. If not specified - assumed to be 0.
                // It has shape `[num_directions, batch_size, hidden_size]`.
                let initial_h_default: Tensor;
                let initial_h = match get_opt(5) {
                    Some(n) => n?,
                    _ => {
                        initial_h_default = Tensor::zeros(
                            (num_directions, batch_size, hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &initial_h_default
                    }
                };

                // Optional initial value of the cell.
                // If not specified - assumed to be 0.
                // It has shape `[num_directions, batch_size, hidden_size]`.
                let initial_c_default: Tensor;
                let initial_c = match node.input.get(6) {
                    Some(n) if !n.is_empty() => get(n)?,
                    _ => {
                        initial_c_default = Tensor::zeros(
                            (num_directions, batch_size, hidden_size as usize),
                            DType::F32,
                            x.device(),
                        )?;
                        &initial_c_default
                    }
                };

                // The weight tensor for peepholes.
                // Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
                // It has shape `[num_directions, 3*hidden_size]`. Optional: If not specified - assumed to be 0.
                let p_default = Tensor::zeros(
                    (num_directions, 3 * hidden_size as usize),
                    DType::F32,
                    x.device(),
                )?;
                let p = get_opt(7).unwrap_or(Ok(&p_default))?;
                let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
                if !p_is_zeros {
                    bail!(
                        "LSTM currently only supports default value of p (a Tensor of all zeroes)"
                    );
                }

                // these all have [num_directions, ...] shapes
                let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
                let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
                let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
                let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
                let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
                let wb = b.index_select(&idx_wb, 0)?;
                let rb = b.index_select(&idx_rb, 0)?;
                let c = initial_c.get(0)?;
                let h = initial_h.get(0)?;

                // w, r, wb, rb are all iofc but lstm expects ifco
                // so we need to move some stuff around
                let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
                let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
                let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
                let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
                let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
                let w = w.index_select(&idx_ifco, 0)?;
                let r = r.index_select(&idx_ifco, 0)?;
                let wb = wb.index_select(&idx_ifco, 0)?;
                let rb = rb.index_select(&idx_ifco, 0)?;
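                // Load the rearranged weights into a VarMap under the parameter names the
                // candle_nn LSTM expects (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0)
                // so the existing cell implementation can be reused as-is.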
                let vmap = candle_nn::VarMap::new();
                vmap.data().lock().unwrap().extend([
                    ("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
                    ("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
                    ("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
                    ("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
                ]);
                use candle_nn::rnn::RNN as _;
                let lstm = candle_nn::rnn::lstm(
                    input_size,
                    hidden_size as usize,
                    candle_nn::rnn::LSTMConfig::default(),
                    candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
                )?;

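                // Run the sequence one timestep at a time, accumulating the per-step hidden
                // states only when the Y output is actually requested.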
                let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
                let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
                    Some(vec![])
                } else {
                    None
                };
                for t in 0..seq_length {
                    let x = x.get(t)?;
                    lstm_state = lstm.step(&x, &lstm_state)?;
                    if let Some(h_acc) = &mut h_acc {
                        h_acc.push(lstm_state.clone());
                    }
                }

                assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
                if let Some(name) = node.output.get(0) {
                    let h_acc = h_acc.as_ref().unwrap();
                    let h_acc = lstm.states_to_tensor(h_acc)?;
                    let h_acc = h_acc.reshape((
                        seq_length,
                        num_directions,
                        batch_size,
                        hidden_size as usize,
                    ))?;
                    values.insert(name.clone(), h_acc);
                }
                if let Some(name) = node.output.get(1) {
                    values.insert(
                        name.clone(),
                        lstm_state.h().reshape((
                            num_directions,
                            batch_size,
                            hidden_size as usize,
                        ))?,
                    );
                }
                if let Some(name) = node.output.get(2) {
                    values.insert(
                        name.clone(),
                        lstm_state.c().reshape((
                            num_directions,
                            batch_size,
                            hidden_size as usize,
                        ))?,
                    );
                }
            }
            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
        }
    }
@@ -6,11 +6,13 @@ extern crate accelerate_src;

use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::eval::Value;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use candle_onnx::simple_eval;
use std::collections::HashMap;

const INPUT_X: &str = "x";
@@ -3514,3 +3516,467 @@ fn test_slice() -> Result<()> {

    Ok(())
}

#[test]
fn test_lstm() -> Result<()> {
    // values generated from pytorch, so at least it's close enough to what pytorch does
    /*
    #!/usr/bin/env python3

    # torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

    import torch

    rand_gen = torch.Generator()
    rand_gen.manual_seed(1)
    input_size = 3
    hidden_size = 5
    batch_size = 1
    sequence_length = 4
    number_directions = 1
    rnn = torch.nn.LSTM(input_size,hidden_size)
    weight_ih_l0 = torch.randn(rnn.weight_ih_l0.shape, generator=rand_gen)
    weight_hh_l0 = torch.randn(rnn.weight_hh_l0.shape, generator=rand_gen)
    bias_ih_l0 = torch.randn(rnn.bias_ih_l0.shape, generator=rand_gen)
    bias_hh_l0 = torch.randn(rnn.bias_hh_l0.shape, generator=rand_gen)
    rnn.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
    rnn.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
    rnn.bias_ih_l0 = torch.nn.Parameter(bias_ih_l0)
    rnn.bias_hh_l0 = torch.nn.Parameter(bias_hh_l0)
    input = torch.randn(sequence_length, batch_size, input_size, generator=rand_gen)
    h0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
    c0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
    output, (hn, cn) = rnn(input, (h0, c0))

    def fmt_tensor(t):
        return "Tensor::from_vec::<_, f32>(vec!"+ str(t.flatten().tolist()) + ", (" + "".join([str(n)+"," for n in t.shape])+"), &Device::Cpu)?"

    print("let input_size = ", input_size, ";")
    print("let hidden_size = ", hidden_size, ";")
    print("let batch_size = ", batch_size, ";")
    print("let sequence_length = ", sequence_length, ";")
    print("let number_directions = ", number_directions, ";")
    print("let weight_ih_l0 = ", fmt_tensor(rnn.weight_ih_l0), ";")
    print("let weight_hh_l0 = ", fmt_tensor(rnn.weight_hh_l0), ";")
    print("let bias_ih_l0 = ", fmt_tensor(rnn.bias_ih_l0), ";")
    print("let bias_hh_l0 = ", fmt_tensor(rnn.bias_hh_l0), ";")
    print("let input = ", fmt_tensor(input), ";")
    print("let h0 = ", fmt_tensor(h0), ";")
    print("let c0 = ", fmt_tensor(c0), ";")
    print("let output = ", fmt_tensor(output), ";")
    print("let hn = ", fmt_tensor(hn), ";")
    print("let cn = ", fmt_tensor(cn), ";")
    */
    let input_size = 3;
    let hidden_size = 5;
    let batch_size = 1;
    let sequence_length = 4;
    let number_directions = 1;
    let weight_ih_l0 = Tensor::from_vec::<_, f32>(
        vec![
            -1.5255959033966064,
            -0.7502318024635315,
            -0.6539809107780457,
            -1.6094847917556763,
            -0.1001671776175499,
            -0.6091889142990112,
            -0.9797722697257996,
            -1.6090962886810303,
            -0.7121446132659912,
            0.30372199416160583,
            -0.777314305305481,
            -0.25145524740219116,
            -0.22227048873901367,
            1.6871134042739868,
            0.22842517495155334,
            0.46763551235198975,
            -0.6969724297523499,
            -1.1607614755630493,
            0.6995424032211304,
            0.1990816295146942,
            0.8656923770904541,
            0.2444038987159729,
            -0.6629113554954529,
            0.8073082566261292,
            1.1016806364059448,
            -0.1759360432624817,
            -2.2455577850341797,
            -1.4464579820632935,
            0.0611552819609642,
            -0.6177444458007812,
            -0.7980698347091675,
            -0.13162320852279663,
            1.8793457746505737,
            -0.07213178277015686,
            0.15777060389518738,
            -0.7734549045562744,
            0.1990565061569214,
            0.04570277780294418,
            0.15295691788196564,
            -0.47567880153656006,
            -0.11101982742547989,
            0.2927352488040924,
            -0.1578451544046402,
            -0.028787139803171158,
            0.4532545804977417,
            1.1421611309051514,
            0.2486107051372528,
            -1.7754007577896118,
            -0.025502461940050125,
            -1.023330569267273,
            -0.5961851477622986,
            -1.0055307149887085,
            0.42854228615760803,
            1.4760777950286865,
            -1.7868678569793701,
            1.610317587852478,
            -0.703956663608551,
            -0.18526579439640045,
            -0.9962350726127625,
            -0.8312552571296692,
        ],
        (20, 3),
        &Device::Cpu,
    )?;
    let weight_hh_l0 = Tensor::from_vec::<_, f32>(
        vec![
            0.4099724292755127,
            0.4084506630897522,
            0.25786539912223816,
            1.095021367073059,
            -0.5064865946769714,
            0.09977540373802185,
            -0.653973400592804,
            0.731693685054779,
            -1.456732988357544,
            1.6089353561401367,
            0.09376997500658035,
            -1.2597490549087524,
            0.25463348627090454,
            -0.5019572973251343,
            -1.041200041770935,
            0.7322672009468079,
            1.3075355291366577,
            -1.1627987623214722,
            0.11963611096143723,
            -0.1631353348493576,
            0.6614453196525574,
            1.1899205446243286,
            0.8165339231491089,
            -0.9135236144065857,
            -0.3538065254688263,
            0.7639270424842834,
            -0.5889506936073303,
            -0.7635973691940308,
            1.3352056741714478,
            0.6042736172676086,
            -0.10344208031892776,
            -0.15121692419052124,
            1.2465683221817017,
            0.505721390247345,
            0.9505112171173096,
            1.2966482639312744,
            0.873796284198761,
            -0.5602594017982483,
            1.2857844829559326,
            0.8168238401412964,
            -1.464799404144287,
            -1.2629283666610718,
            1.122018814086914,
            1.5663341283798218,
            2.558138370513916,
            -0.23336388170719147,
            -0.013472129590809345,
            1.8606348037719727,
            1.549620509147644,
            0.34762924909591675,
            0.09300802648067474,
            0.6147403120994568,
            0.7123645544052124,
            -1.7765072584152222,
            0.3538645803928375,
            1.1996132135391235,
            -0.7122589349746704,
            -0.620034396648407,
            -0.22813494503498077,
            -0.7892746329307556,
            -1.6111117601394653,
            -1.8716129064559937,
            0.5430836081504822,
            0.6606786251068115,
            0.270527720451355,
            0.5596919655799866,
            -0.31839630007743835,
            1.5117206573486328,
            -1.363267183303833,
            -0.9832196235656738,
            1.5112667083740234,
            0.6418707370758057,
            -0.7474458813667297,
            -0.923438549041748,
            0.5733984112739563,
            -0.10929951071739197,
            0.5181121230125427,
            0.10653535276651382,
            0.26924076676368713,
            1.3247679471969604,
            0.037456899881362915,
            -0.6378393173217773,
            -0.8147554397583008,
            -0.6895065307617188,
            0.8436542749404907,
            1.1657012701034546,
            0.5269321799278259,
            1.6192532777786255,
            -0.963976263999939,
            0.14152038097381592,
            -0.1636609584093094,
            -0.3582225739955902,
            1.7222793102264404,
            -0.3035756051540375,
            0.23887419700622559,
            1.3440011739730835,
            0.1032256931066513,
            1.1003541946411133,
            -0.3416801989078522,
            0.947338879108429,
        ],
        (20, 5),
        &Device::Cpu,
    )?;
    let bias_ih_l0 = Tensor::from_vec::<_, f32>(
        vec![
            -0.568515956401825,
            0.8375961780548096,
            1.783660650253296,
            -0.1954246610403061,
            0.235193133354187,
            1.9142433404922485,
            1.8364111185073853,
            1.324532389640808,
            -0.07051458209753036,
            0.34697940945625305,
            -0.653679609298706,
            1.5586202144622803,
            0.2185661494731903,
            -0.5743072628974915,
            1.4571250677108765,
            1.7709556818008423,
            -2.0172998905181885,
            0.42350319027900696,
            0.5730220079421997,
            -1.7962429523468018,
        ],
        (20,),
        &Device::Cpu,
    )?;
    let bias_hh_l0 = Tensor::from_vec::<_, f32>(
        vec![
            1.2470403909683228,
            1.2738511562347412,
            0.3909492492675781,
            0.387210488319397,
            0.14440394937992096,
            0.7771684527397156,
            -2.3381125926971436,
            -0.829120397567749,
            1.1661391258239746,
            1.4786574840545654,
            0.26760873198509216,
            0.7561198472976685,
            -0.5873361229896545,
            -2.061920642852783,
            0.4304734766483307,
            0.3376566171646118,
            -0.3437853455543518,
            -0.6172260642051697,
            1.2529692649841309,
            -0.05141742154955864,
        ],
        (20,),
        &Device::Cpu,
    )?;
    let input = Tensor::from_vec::<_, f32>(
        vec![
            0.6472128033638,
            -0.04116716980934143,
            -0.17749308049678802,
            -0.500039279460907,
            0.8672749400138855,
            -0.27319222688674927,
            -0.4607681334018707,
            -0.0990937128663063,
            0.47284480929374695,
            1.0049484968185425,
            -0.2871420383453369,
            -1.1618621349334717,
        ],
        (4, 1, 3),
        &Device::Cpu,
    )?;
    let h0 = Tensor::from_vec::<_, f32>(
        vec![
            0.02758178487420082,
            0.5652382373809814,
            -0.011487378738820553,
            0.6706400513648987,
            -0.4929250478744507,
        ],
        (1, 1, 5),
        &Device::Cpu,
    )?;
    let c0 = Tensor::from_vec::<_, f32>(
        vec![
            1.505028486251831,
            -2.32635498046875,
            1.6168899536132812,
            -0.9026237726211548,
            0.17366823554039001,
        ],
        (1, 1, 5),
        &Device::Cpu,
    )?;
    let output = Tensor::from_vec::<_, f32>(
        vec![
            0.5956016778945923,
            -0.01723279245197773,
            0.11035571992397308,
            -0.49323174357414246,
            0.047632161527872086,
            0.6358451843261719,
            0.040328118950128555,
            -0.3788611590862274,
            -0.7464339733123779,
            0.20080909132957458,
            0.5840265154838562,
            0.1453288197517395,
            -0.7345298528671265,
            -0.5214304327964783,
            0.21903817355632782,
            0.7420451641082764,
            0.31943878531455994,
            -0.04726646468043327,
            -0.2823849618434906,
            0.2713133990764618,
        ],
        (4, 1, 5),
        &Device::Cpu,
    )?;
    let hn = Tensor::from_vec::<_, f32>(
        vec![
            0.7420451641082764,
            0.31943878531455994,
            -0.04726646468043327,
            -0.2823849618434906,
            0.2713133990764618,
        ],
        (1, 1, 5),
        &Device::Cpu,
    )?;
    let cn = Tensor::from_vec::<_, f32>(
        vec![
            0.9630558490753174,
            1.0033069849014282,
            -1.754899024963379,
            -1.5967122316360474,
            0.8252924680709839,
        ],
        (1, 1, 5),
        &Device::Cpu,
    )?;
    // end of generated values

    let model = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "LSTM".to_string(),
            name: "LSTM_test".to_string(),
            attribute: vec![AttributeProto {
                name: "hidden_size".to_string(),
                r#type: AttributeType::Int.into(),
                i: hidden_size as i64,
                ..AttributeProto::default()
            }],
            input: vec![
                "input".to_string(),
                "w".to_string(),
                "r".to_string(),
                "b".to_string(), // b
                "".to_string(),  // seq_lens
                "h".to_string(),
                "c".to_string(),
            ],
            output: vec!["output".to_string(), "hn".to_string(), "cn".to_string()],
            ..NodeProto::default()
        }],
        input: ["input", "w", "r", "b", "h", "c"]
            .into_iter()
            .map(|name| ValueInfoProto {
                name: name.to_string(),
                ..ValueInfoProto::default()
            })
            .collect(),
        output: ["output", "hn", "cn"]
            .into_iter()
            .map(|name| ValueInfoProto {
                name: name.to_string(),
                ..ValueInfoProto::default()
            })
            .collect(),
        ..GraphProto::default()
    }));
    // pytorch stores weight and bias as [ifco] but we want it as [iofc]
    // so we need to re-arrange the tensors a bit
    let idx_iofc = {
        let stride = hidden_size as i64;
        let dev = weight_ih_l0.device();
        let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
        let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
        let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
        let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;

        Tensor::cat(&[&idx_i, &idx_o, &idx_f, &idx_g], 0)?
    };
    let w = weight_ih_l0.index_select(&idx_iofc, 0)?;
    let w = w.reshape((number_directions, 4 * hidden_size, input_size))?;
    let r = weight_hh_l0.index_select(&idx_iofc, 0)?;
    let r = r.reshape((number_directions, 4 * hidden_size, hidden_size))?;
    let wb = bias_ih_l0.index_select(&idx_iofc, 0)?;
    let rb = bias_hh_l0.index_select(&idx_iofc, 0)?;
    let b = Tensor::cat(&[wb, rb], 0)?.reshape((number_directions, 8 * hidden_size))?;
    let output = output.reshape((sequence_length, number_directions, batch_size, hidden_size))?;
    let result = simple_eval(
        &model,
        HashMap::from_iter([
            ("input".to_string(), input),
            ("w".to_string(), w),
            ("r".to_string(), r),
            ("b".to_string(), b),
            ("h".to_string(), h0),
            ("c".to_string(), c0),
        ]),
    )?;
    let actual_output = result.get("output").unwrap();
    assert_eq!(output.dims(), actual_output.dims());
    let actual_hn = result.get("hn").unwrap();
    assert_eq!(hn.dims(), actual_hn.dims());
    let actual_cn = result.get("cn").unwrap();
    assert_eq!(cn.dims(), actual_cn.dims());
    let diff_close_enough = |a: &Tensor, b| -> Result<_> {
        let diffs = a.sub(b)?.flatten_all()?.to_vec1::<f32>()?;
        Ok(diffs.iter().all(|f| f.abs() < 0.0001))
    };
    assert!(
        diff_close_enough(&output, &actual_output)?,
        "output did not match expected\n{actual_output}\n{output}",
    );
    assert!(
        diff_close_enough(&hn, &actual_hn)?,
        "hn did not match expected\n{actual_hn}\n{hn}",
    );
    assert!(
        diff_close_enough(&cn, &actual_cn)?,
        "cn did not match expected\n{actual_cn}\n{cn}",
    );

    Ok(())
}
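
Once this lands, the new code path can presumably be exercised through the crate's test suite, along the lines of:

cargo test -p candle-onnx test_lstm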