onnx: implement LSTM op (#2268)

use candle-nn LSTM
This commit is contained in:
shua 2024-08-19 09:06:17 +02:00 committed by GitHub
parent 236b29ff15
commit 31a1075f4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 709 additions and 0 deletions

View File

@ -55,6 +55,10 @@ pub struct LSTMState {
}
impl LSTMState {
pub fn new(h: Tensor, c: Tensor) -> Self {
LSTMState { h, c }
}
/// The hidden state vector, which is also the output of the LSTM.
pub fn h(&self) -> &Tensor {
&self.h

View File

@ -66,6 +66,18 @@ impl Attr for GraphProto {
}
}
impl AttrOwned for Vec<String> {
const TYPE: AttributeType = AttributeType::Strings;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
let mut ret = vec![];
for bytes in attr.strings.iter() {
let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
ret.push(s);
}
Ok(ret)
}
}
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@ -1310,6 +1322,233 @@ fn simple_eval_(
.broadcast_add(&c.broadcast_mul(&beta)?)?;
values.insert(node.output[0].clone(), output);
}
"LSTM" => {
let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
if direction != "forward" {
bail!("LSTM currently only supports direction == \"forward\"");
}
let num_directions = if direction == "bidirectional" { 2 } else { 1 };
let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
if input_forget != 0 {
bail!("LSTM currently only supports input_forget == 0");
}
let activations_default = vec![
"Sigmoid".to_string(),
"Tanh".to_string(),
"Tanh".to_string(),
];
let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
.unwrap_or(activations_default.clone());
if activations != activations_default {
bail!("LSTM currently only supports default activations ({activations_default:?})");
}
// activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
if get_attr_opt::<f32>(node, "clip")?.is_some() {
bail!("LSTM does not currently support clip attribute");
}
// The shape format of inputs X, initial_h and outputs Y, Y_h.
// If 0, the following shapes are expected:
// X.shape = [seq_length, batch_size, input_size],
// Y.shape = [seq_length, num_directions, batch_size, hidden_size],
// initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
// If 1, the following shapes are expected:
// X.shape = [batch_size, seq_length, input_size],
// Y.shape = [batch_size, seq_length, num_directions, hidden_size],
// initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
if layout != 0 {
bail!("LSTM currently only supports layout == 0");
}
// The input sequences packed (and potentially padded) into one 3-D tensor
// with the shape of `[seq_length, batch_size, input_size]`.
let x = get(&node.input[0])?;
// XXX: depends on layout
let (seq_length, batch_size, input_size) = x.dims3()?;
// The weight tensor for the gates.
// Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
// The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
let w = get(&node.input[1])?;
// The recurrence weight tensor.
// Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
// Optional: If not specified - assumed to be 0.
let b_default: Tensor;
let b = match get_opt(3) {
Some(n) => n?,
None => {
b_default = Tensor::zeros(
(num_directions, 8 * hidden_size as usize),
DType::F32,
x.device(),
)?;
&b_default
}
};
// Optional tensor specifying lengths of the sequences in a batch.
// If not specified - assumed all sequences in the batch to have length `seq_length`.
// It has shape `[batch_size]`.
let seq_lens_default: Tensor;
let seq_lens = match get_opt(4) {
Some(n) => n?,
None => {
seq_lens_default =
Tensor::full(seq_length as i64, (batch_size,), x.device())?;
&seq_lens_default
}
};
let seq_lens_is_default =
(seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
if !seq_lens_is_default {
bail!("LSTM currently only supports default value of seq_lens");
}
// Optional initial value of the hidden. If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_h_default: Tensor;
let initial_h = match get_opt(5) {
Some(n) => n?,
_ => {
initial_h_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_h_default
}
};
// Optional initial value of the cell.
// If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_c_default: Tensor;
let initial_c = match node.input.get(6) {
Some(n) if !n.is_empty() => get(n)?,
_ => {
initial_c_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_c_default
}
};
// The weight tensor for peepholes.
// Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
// It has shape `[num_directions, 3*hidde_size]`. Optional: If not specified - assumed to be 0.
let p_default = Tensor::zeros(
(num_directions, 3 * hidden_size as usize),
DType::F32,
x.device(),
)?;
let p = get_opt(7).unwrap_or(Ok(&p_default))?;
let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
if !p_is_zeros {
bail!(
"LSTM currently only supports default value of p (a Tensor of all zeroes)"
);
}
// these all have [num_directions, ...] shapes
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
let c = initial_c.get(0)?;
let h = initial_h.get(0)?;
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
let w = w.index_select(&idx_ifco, 0)?;
let r = r.index_select(&idx_ifco, 0)?;
let wb = wb.index_select(&idx_ifco, 0)?;
let rb = rb.index_select(&idx_ifco, 0)?;
let vmap = candle_nn::VarMap::new();
vmap.data().lock().unwrap().extend([
("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
]);
use candle_nn::rnn::RNN as _;
let lstm = candle_nn::rnn::lstm(
input_size,
hidden_size as usize,
candle_nn::rnn::LSTMConfig::default(),
candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
};
for t in 0..seq_length {
let x = x.get(t)?;
lstm_state = lstm.step(&x, &lstm_state)?;
if let Some(h_acc) = &mut h_acc {
h_acc.push(lstm_state.clone());
}
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
seq_length,
num_directions,
batch_size,
hidden_size as usize,
))?;
values.insert(name.clone(), h_acc);
}
if let Some(name) = node.output.get(1) {
values.insert(
name.clone(),
lstm_state.h().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
if let Some(name) = node.output.get(2) {
values.insert(
name.clone(),
lstm_state.c().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

View File

@ -6,11 +6,13 @@ extern crate accelerate_src;
use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::eval::Value;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use candle_onnx::simple_eval;
use std::collections::HashMap;
const INPUT_X: &str = "x";
@ -3514,3 +3516,467 @@ fn test_slice() -> Result<()> {
Ok(())
}
#[test]
fn test_lstm() -> Result<()> {
// values generated from pytorch, so at least it's close enough to what pytorch does
/*
#!/usr/bin/env python3
# torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)
import torch
rand_gen = torch.Generator()
rand_gen.manual_seed(1)
input_size = 3
hidden_size = 5
batch_size = 1
sequence_length = 4
number_directions = 1
rnn = torch.nn.LSTM(input_size,hidden_size)
weight_ih_l0 = torch.randn(rnn.weight_ih_l0.shape, generator=rand_gen)
weight_hh_l0 = torch.randn(rnn.weight_hh_l0.shape, generator=rand_gen)
bias_ih_l0 = torch.randn(rnn.bias_ih_l0.shape, generator=rand_gen)
bias_hh_l0 = torch.randn(rnn.bias_hh_l0.shape, generator=rand_gen)
rnn.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0)
rnn.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0)
rnn.bias_ih_l0 = torch.nn.Parameter(bias_ih_l0)
rnn.bias_hh_l0 = torch.nn.Parameter(bias_hh_l0)
input = torch.randn(sequence_length, batch_size, input_size, generator=rand_gen)
h0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
c0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen)
output, (hn, cn) = rnn(input, (h0, c0))
def fmt_tensor(t):
return "Tensor::from_vec::<_, f32>(vec!"+ str(t.flatten().tolist()) + ", (" + "".join([str(n)+"," for n in t.shape])+"), &Device::Cpu)?"
print("let input_size = ", input_size, ";")
print("let hidden_size = ", hidden_size, ";")
print("let batch_size = ", batch_size, ";")
print("let sequence_length = ", sequence_length, ";")
print("let number_directions = ", number_directions, ";")
print("let weight_ih_l0 = ", fmt_tensor(rnn.weight_ih_l0), ";")
print("let weight_hh_l0 = ", fmt_tensor(rnn.weight_hh_l0), ";")
print("let bias_ih_l0 = ", fmt_tensor(rnn.bias_ih_l0), ";")
print("let bias_hh_l0 = ", fmt_tensor(rnn.bias_hh_l0), ";")
print("let input = ", fmt_tensor(input), ";")
print("let h0 = ", fmt_tensor(h0), ";")
print("let c0 = ", fmt_tensor(c0), ";")
print("let output = ", fmt_tensor(output), ";")
print("let hn = ", fmt_tensor(hn), ";")
print("let cn = ", fmt_tensor(cn), ";")
*/
let input_size = 3;
let hidden_size = 5;
let batch_size = 1;
let sequence_length = 4;
let number_directions = 1;
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-1.5255959033966064,
-0.7502318024635315,
-0.6539809107780457,
-1.6094847917556763,
-0.1001671776175499,
-0.6091889142990112,
-0.9797722697257996,
-1.6090962886810303,
-0.7121446132659912,
0.30372199416160583,
-0.777314305305481,
-0.25145524740219116,
-0.22227048873901367,
1.6871134042739868,
0.22842517495155334,
0.46763551235198975,
-0.6969724297523499,
-1.1607614755630493,
0.6995424032211304,
0.1990816295146942,
0.8656923770904541,
0.2444038987159729,
-0.6629113554954529,
0.8073082566261292,
1.1016806364059448,
-0.1759360432624817,
-2.2455577850341797,
-1.4464579820632935,
0.0611552819609642,
-0.6177444458007812,
-0.7980698347091675,
-0.13162320852279663,
1.8793457746505737,
-0.07213178277015686,
0.15777060389518738,
-0.7734549045562744,
0.1990565061569214,
0.04570277780294418,
0.15295691788196564,
-0.47567880153656006,
-0.11101982742547989,
0.2927352488040924,
-0.1578451544046402,
-0.028787139803171158,
0.4532545804977417,
1.1421611309051514,
0.2486107051372528,
-1.7754007577896118,
-0.025502461940050125,
-1.023330569267273,
-0.5961851477622986,
-1.0055307149887085,
0.42854228615760803,
1.4760777950286865,
-1.7868678569793701,
1.610317587852478,
-0.703956663608551,
-0.18526579439640045,
-0.9962350726127625,
-0.8312552571296692,
],
(20, 3),
&Device::Cpu,
)?;
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
0.4099724292755127,
0.4084506630897522,
0.25786539912223816,
1.095021367073059,
-0.5064865946769714,
0.09977540373802185,
-0.653973400592804,
0.731693685054779,
-1.456732988357544,
1.6089353561401367,
0.09376997500658035,
-1.2597490549087524,
0.25463348627090454,
-0.5019572973251343,
-1.041200041770935,
0.7322672009468079,
1.3075355291366577,
-1.1627987623214722,
0.11963611096143723,
-0.1631353348493576,
0.6614453196525574,
1.1899205446243286,
0.8165339231491089,
-0.9135236144065857,
-0.3538065254688263,
0.7639270424842834,
-0.5889506936073303,
-0.7635973691940308,
1.3352056741714478,
0.6042736172676086,
-0.10344208031892776,
-0.15121692419052124,
1.2465683221817017,
0.505721390247345,
0.9505112171173096,
1.2966482639312744,
0.873796284198761,
-0.5602594017982483,
1.2857844829559326,
0.8168238401412964,
-1.464799404144287,
-1.2629283666610718,
1.122018814086914,
1.5663341283798218,
2.558138370513916,
-0.23336388170719147,
-0.013472129590809345,
1.8606348037719727,
1.549620509147644,
0.34762924909591675,
0.09300802648067474,
0.6147403120994568,
0.7123645544052124,
-1.7765072584152222,
0.3538645803928375,
1.1996132135391235,
-0.7122589349746704,
-0.620034396648407,
-0.22813494503498077,
-0.7892746329307556,
-1.6111117601394653,
-1.8716129064559937,
0.5430836081504822,
0.6606786251068115,
0.270527720451355,
0.5596919655799866,
-0.31839630007743835,
1.5117206573486328,
-1.363267183303833,
-0.9832196235656738,
1.5112667083740234,
0.6418707370758057,
-0.7474458813667297,
-0.923438549041748,
0.5733984112739563,
-0.10929951071739197,
0.5181121230125427,
0.10653535276651382,
0.26924076676368713,
1.3247679471969604,
0.037456899881362915,
-0.6378393173217773,
-0.8147554397583008,
-0.6895065307617188,
0.8436542749404907,
1.1657012701034546,
0.5269321799278259,
1.6192532777786255,
-0.963976263999939,
0.14152038097381592,
-0.1636609584093094,
-0.3582225739955902,
1.7222793102264404,
-0.3035756051540375,
0.23887419700622559,
1.3440011739730835,
0.1032256931066513,
1.1003541946411133,
-0.3416801989078522,
0.947338879108429,
],
(20, 5),
&Device::Cpu,
)?;
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-0.568515956401825,
0.8375961780548096,
1.783660650253296,
-0.1954246610403061,
0.235193133354187,
1.9142433404922485,
1.8364111185073853,
1.324532389640808,
-0.07051458209753036,
0.34697940945625305,
-0.653679609298706,
1.5586202144622803,
0.2185661494731903,
-0.5743072628974915,
1.4571250677108765,
1.7709556818008423,
-2.0172998905181885,
0.42350319027900696,
0.5730220079421997,
-1.7962429523468018,
],
(20,),
&Device::Cpu,
)?;
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
1.2470403909683228,
1.2738511562347412,
0.3909492492675781,
0.387210488319397,
0.14440394937992096,
0.7771684527397156,
-2.3381125926971436,
-0.829120397567749,
1.1661391258239746,
1.4786574840545654,
0.26760873198509216,
0.7561198472976685,
-0.5873361229896545,
-2.061920642852783,
0.4304734766483307,
0.3376566171646118,
-0.3437853455543518,
-0.6172260642051697,
1.2529692649841309,
-0.05141742154955864,
],
(20,),
&Device::Cpu,
)?;
let input = Tensor::from_vec::<_, f32>(
vec![
0.6472128033638,
-0.04116716980934143,
-0.17749308049678802,
-0.500039279460907,
0.8672749400138855,
-0.27319222688674927,
-0.4607681334018707,
-0.0990937128663063,
0.47284480929374695,
1.0049484968185425,
-0.2871420383453369,
-1.1618621349334717,
],
(4, 1, 3),
&Device::Cpu,
)?;
let h0 = Tensor::from_vec::<_, f32>(
vec![
0.02758178487420082,
0.5652382373809814,
-0.011487378738820553,
0.6706400513648987,
-0.4929250478744507,
],
(1, 1, 5),
&Device::Cpu,
)?;
let c0 = Tensor::from_vec::<_, f32>(
vec![
1.505028486251831,
-2.32635498046875,
1.6168899536132812,
-0.9026237726211548,
0.17366823554039001,
],
(1, 1, 5),
&Device::Cpu,
)?;
let output = Tensor::from_vec::<_, f32>(
vec![
0.5956016778945923,
-0.01723279245197773,
0.11035571992397308,
-0.49323174357414246,
0.047632161527872086,
0.6358451843261719,
0.040328118950128555,
-0.3788611590862274,
-0.7464339733123779,
0.20080909132957458,
0.5840265154838562,
0.1453288197517395,
-0.7345298528671265,
-0.5214304327964783,
0.21903817355632782,
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
],
(4, 1, 5),
&Device::Cpu,
)?;
let hn = Tensor::from_vec::<_, f32>(
vec![
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
],
(1, 1, 5),
&Device::Cpu,
)?;
let cn = Tensor::from_vec::<_, f32>(
vec![
0.9630558490753174,
1.0033069849014282,
-1.754899024963379,
-1.5967122316360474,
0.8252924680709839,
],
(1, 1, 5),
&Device::Cpu,
)?;
// end of generated values
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LSTM".to_string(),
name: "LSTM_test".to_string(),
attribute: vec![AttributeProto {
name: "hidden_size".to_string(),
r#type: AttributeType::Int.into(),
i: hidden_size as i64,
..AttributeProto::default()
}],
input: vec![
"input".to_string(),
"w".to_string(),
"r".to_string(),
"b".to_string(), // b
"".to_string(), // seq_lens
"h".to_string(),
"c".to_string(),
],
output: vec!["output".to_string(), "hn".to_string(), "cn".to_string()],
..NodeProto::default()
}],
input: ["input", "w", "r", "b", "h", "c"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: ["output", "hn", "cn"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
..GraphProto::default()
}));
// pytorch stores weight and bias as [ifco] but we want it as [iofc]
// so we need to re-arrange the tensors a bit
let idx_iofc = {
let stride = hidden_size as i64;
let dev = weight_ih_l0.device();
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
Tensor::cat(&[&idx_i, &idx_o, &idx_f, &idx_g], 0)?
};
let w = weight_ih_l0.index_select(&idx_iofc, 0)?;
let w = w.reshape((number_directions, 4 * hidden_size, input_size))?;
let r = weight_hh_l0.index_select(&idx_iofc, 0)?;
let r = r.reshape((number_directions, 4 * hidden_size, hidden_size))?;
let wb = bias_ih_l0.index_select(&idx_iofc, 0)?;
let rb = bias_hh_l0.index_select(&idx_iofc, 0)?;
let b = Tensor::cat(&[wb, rb], 0)?.reshape((number_directions, 8 * hidden_size))?;
let output = output.reshape((sequence_length, number_directions, batch_size, hidden_size))?;
let result = simple_eval(
&model,
HashMap::from_iter([
("input".to_string(), input),
("w".to_string(), w),
("r".to_string(), r),
("b".to_string(), b),
("h".to_string(), h0),
("c".to_string(), c0),
]),
)?;
let actual_output = result.get("output").unwrap();
assert_eq!(output.dims(), actual_output.dims());
let actual_hn = result.get("hn").unwrap();
assert_eq!(hn.dims(), actual_hn.dims());
let actual_cn = result.get("cn").unwrap();
assert_eq!(cn.dims(), actual_cn.dims());
let diff_close_enough = |a: &Tensor, b| -> Result<_> {
let diffs = a.sub(b)?.flatten_all()?.to_vec1::<f32>()?;
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
};
assert!(
diff_close_enough(&output, &actual_output)?,
"output did not match expected\n{actual_output}\n{output}",
);
assert!(
diff_close_enough(&hn, &actual_hn)?,
"hn did not match expected\n{actual_hn}\n{hn}",
);
assert!(
diff_close_enough(&cn, &actual_cn)?,
"cn did not match expected\n{actual_cn}\n{cn}",
);
Ok(())
}