diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index 6e37af9..488a317 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -345,33 +345,35 @@ where } } -// TODO: Implement TryParallel -// pub struct TryParallel { -// op1: Op1, -// op2: Op2, -// } +// Implement TryParallel +pub struct TryParallel { + op1: Op1, + op2: Op2, +} -// impl TryParallel { -// pub fn new(op1: Op1, op2: Op2) -> Self { -// Self { op1, op2 } -// } -// } +impl TryParallel { + pub fn new(op1: Op1, op2: Op2) -> Self { + Self { op1, op2 } + } +} -// impl TryOp for TryParallel -// where -// Op1: TryOp, -// Op2: TryOp, -// { -// type Input = Op1::Input; -// type Output = (Op1::Output, Op2::Output); -// type Error = Op1::Error; +impl op::Op for TryParallel +where + Op1: TryOp, + Op2: TryOp, +{ + type Input = Op1::Input; + type Output = Result<(Op1::Output, Op2::Output), Op1::Error>; -// #[inline] -// async fn try_call(&self, input: Self::Input) -> Result { -// let (output1, output2) = tokio::join!(self.op1.try_call(input.clone()), self.op2.try_call(input)); -// Ok((output1?, output2?)) -// } -// } + #[inline] + async fn try_call(&self, input: Self::Input) -> Result<(Op1::Output, Op2::Output), Op1::Error> { + use futures::try_join; + try_join!( + self.op1.try_call(input.clone()), + self.op2.try_call(input) + ) + } +} #[cfg(test)] mod tests { @@ -472,4 +474,29 @@ mod tests { let result = pipeline.try_call(1).await.unwrap(); assert_eq!(result, 15); } + + #[tokio::test] + async fn test_try_parallel() { + let op1 = map(|x: i32| { + if x % 2 == 0 { + Ok(x + 1) + } else { + Err("x is odd") + } + }); + let op2 = map(|x: i32| { + if x % 2 == 0 { + Ok(x * 2) + } else { + Err("x is odd") + } + }); + let pipeline = TryParallel::new(op1, op2); + + let result = pipeline.try_call(2).await; + assert_eq!(result, Ok((3, 4))); + + let result = pipeline.try_call(1).await; + assert_eq!(result, Err("x is odd")); + } }