wrpc_transport/frame/
oneshot.rs

1//! wRPC transport stream framing
2
3use core::future::Future;
4
5use bytes::Bytes;
6use tokio::io::{duplex, split, AsyncRead, AsyncWrite, DuplexStream, ReadHalf, WriteHalf};
7use tracing::instrument;
8
9use crate::frame::{invoke, Incoming, Outgoing};
10use crate::{Accept, Invoke};
11
/// [Invoke] and [Accept] implementation in terms of a single stream pair.
///
/// The wrapped `(rx, tx)` stream pair can be claimed only once. The first call to
/// [`Invoke::invoke`], [`Accept::accept`] or [`Oneshot::try_take_inner`] takes it out.
/// Every later call on the same [Oneshot] returns an error instead.
#[derive(Debug)]
pub struct Oneshot<I, O>(std::sync::Mutex<Option<(I, O)>>);
18
19impl<I, O> From<(I, O)> for Oneshot<I, O> {
20    fn from((rx, tx): (I, O)) -> Self {
21        Self(std::sync::Mutex::new(Some((rx, tx))))
22    }
23}
24
25impl From<DuplexStream> for Oneshot<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>> {
26    fn from(stream: DuplexStream) -> Self {
27        split(stream).into()
28    }
29}
30
31impl Oneshot<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>> {
32    /// Creates a pair of connected [Oneshot] using [tokio::io::duplex].
33    pub fn duplex(max_buf_size: usize) -> (Self, Self) {
34        let (a, b) = duplex(max_buf_size);
35        (a.into(), b.into())
36    }
37}
38
39impl<I, O> Oneshot<I, O> {
40    /// Returns the inner stream pair if [Oneshot] has not been used yet or an error.
41    pub fn try_take_inner(&self) -> std::io::Result<(I, O)> {
42        match self.0.try_lock().map(|mut stream| stream.take()) {
43            Ok(Some((rx, tx))) => Ok((rx, tx)),
44            Ok(None) | Err(std::sync::TryLockError::WouldBlock) => Err(std::io::Error::new(
45                std::io::ErrorKind::UnexpectedEof,
46                "stream was already used",
47            )),
48            Err(std::sync::TryLockError::Poisoned(..)) => {
49                Err(std::io::Error::other("stream lock poisoned"))
50            }
51        }
52    }
53}
54
55impl<I, O> Invoke for Oneshot<I, O>
56where
57    I: AsyncRead + Send + Unpin + 'static,
58    O: AsyncWrite + Send + Unpin + 'static,
59{
60    type Context = ();
61    type Outgoing = Outgoing;
62    type Incoming = Incoming;
63
64    async fn invoke<P>(
65        &self,
66        cx: Self::Context,
67        instance: &str,
68        func: &str,
69        params: Bytes,
70        paths: impl AsRef<[P]> + Send,
71    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
72    where
73        P: AsRef<[Option<usize>]> + Send + Sync,
74    {
75        (&self).invoke(cx, instance, func, params, paths).await
76    }
77}
78
79impl<I, O> Invoke for &Oneshot<I, O>
80where
81    I: AsyncRead + Send + Unpin + 'static,
82    O: AsyncWrite + Send + Unpin + 'static,
83{
84    type Context = ();
85    type Outgoing = Outgoing;
86    type Incoming = Incoming;
87
88    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
89    fn invoke<P>(
90        &self,
91        (): Self::Context,
92        instance: &str,
93        func: &str,
94        params: Bytes,
95        paths: impl AsRef<[P]> + Send,
96    ) -> impl Future<Output = anyhow::Result<(Self::Outgoing, Self::Incoming)>>
97    where
98        P: AsRef<[Option<usize>]> + Send + Sync,
99    {
100        let stream = self.try_take_inner();
101        async move {
102            let (rx, tx) = stream?;
103            invoke(tx, rx, instance, func, params, paths).await
104        }
105    }
106}
107
108impl<I, O> Accept for Oneshot<I, O>
109where
110    I: AsyncRead + Send + Sync + Unpin + 'static,
111    O: AsyncWrite + Send + Sync + Unpin + 'static,
112{
113    type Context = ();
114    type Outgoing = O;
115    type Incoming = I;
116
117    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
118        (&self).accept().await
119    }
120}
121
122impl<I, O> Accept for &Oneshot<I, O>
123where
124    I: AsyncRead + Send + Sync + Unpin + 'static,
125    O: AsyncWrite + Send + Sync + Unpin + 'static,
126{
127    type Context = ();
128    type Outgoing = O;
129    type Incoming = I;
130
131    #[instrument(level = "trace", skip(self))]
132    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
133        let (rx, tx) = self.try_take_inner()?;
134        Ok(((), tx, rx))
135    }
136}