wrpc_transport/frame/
oneshot.rs1use core::future::Future;
4
5use bytes::Bytes;
6use tokio::io::{duplex, split, AsyncRead, AsyncWrite, DuplexStream, ReadHalf, WriteHalf};
7use tracing::instrument;
8
9use crate::frame::{invoke, Incoming, Outgoing};
10use crate::{Accept, Invoke};
11
/// A transport that hands out one pre-established `(rx, tx)` stream pair.
///
/// The pair lives behind a [`std::sync::Mutex`] so it can be moved out through a shared
/// reference. Once it has been taken, the slot is left empty (`None`).
#[derive(Debug)]
pub struct Oneshot<I, O>(std::sync::Mutex<Option<(I, O)>>);

impl<I, O> From<(I, O)> for Oneshot<I, O> {
    /// Wraps an `(rx, tx)` pair so that it can be taken out later exactly once.
    fn from(pair: (I, O)) -> Self {
        let slot = Some(pair);
        Self(std::sync::Mutex::new(slot))
    }
}

25impl From<DuplexStream> for Oneshot<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>> {
26 fn from(stream: DuplexStream) -> Self {
27 split(stream).into()
28 }
29}
30
31impl Oneshot<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>> {
32 pub fn duplex(max_buf_size: usize) -> (Self, Self) {
34 let (a, b) = duplex(max_buf_size);
35 (a.into(), b.into())
36 }
37}
38
39impl<I, O> Oneshot<I, O> {
40 pub fn try_take_inner(&self) -> std::io::Result<(I, O)> {
42 match self.0.try_lock().map(|mut stream| stream.take()) {
43 Ok(Some((rx, tx))) => Ok((rx, tx)),
44 Ok(None) | Err(std::sync::TryLockError::WouldBlock) => Err(std::io::Error::new(
45 std::io::ErrorKind::UnexpectedEof,
46 "stream was already used",
47 )),
48 Err(std::sync::TryLockError::Poisoned(..)) => {
49 Err(std::io::Error::other("stream lock poisoned"))
50 }
51 }
52 }
53}
54
55impl<I, O> Invoke for Oneshot<I, O>
56where
57 I: AsyncRead + Send + Unpin + 'static,
58 O: AsyncWrite + Send + Unpin + 'static,
59{
60 type Context = ();
61 type Outgoing = Outgoing;
62 type Incoming = Incoming;
63
64 async fn invoke<P>(
65 &self,
66 cx: Self::Context,
67 instance: &str,
68 func: &str,
69 params: Bytes,
70 paths: impl AsRef<[P]> + Send,
71 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
72 where
73 P: AsRef<[Option<usize>]> + Send + Sync,
74 {
75 (&self).invoke(cx, instance, func, params, paths).await
76 }
77}
78
79impl<I, O> Invoke for &Oneshot<I, O>
80where
81 I: AsyncRead + Send + Unpin + 'static,
82 O: AsyncWrite + Send + Unpin + 'static,
83{
84 type Context = ();
85 type Outgoing = Outgoing;
86 type Incoming = Incoming;
87
88 #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
89 fn invoke<P>(
90 &self,
91 (): Self::Context,
92 instance: &str,
93 func: &str,
94 params: Bytes,
95 paths: impl AsRef<[P]> + Send,
96 ) -> impl Future<Output = anyhow::Result<(Self::Outgoing, Self::Incoming)>>
97 where
98 P: AsRef<[Option<usize>]> + Send + Sync,
99 {
100 let stream = self.try_take_inner();
101 async move {
102 let (rx, tx) = stream?;
103 invoke(tx, rx, instance, func, params, paths).await
104 }
105 }
106}
107
108impl<I, O> Accept for Oneshot<I, O>
109where
110 I: AsyncRead + Send + Sync + Unpin + 'static,
111 O: AsyncWrite + Send + Sync + Unpin + 'static,
112{
113 type Context = ();
114 type Outgoing = O;
115 type Incoming = I;
116
117 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
118 (&self).accept().await
119 }
120}
121
122impl<I, O> Accept for &Oneshot<I, O>
123where
124 I: AsyncRead + Send + Sync + Unpin + 'static,
125 O: AsyncWrite + Send + Sync + Unpin + 'static,
126{
127 type Context = ();
128 type Outgoing = O;
129 type Incoming = I;
130
131 #[instrument(level = "trace", skip(self))]
132 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
133 let (rx, tx) = self.try_take_inner()?;
134 Ok(((), tx, rx))
135 }
136}