wrpc_transport/frame/conn/
accept.rs

1use core::future::Future;
2use core::ops::{Deref, DerefMut};
3
4use futures::{Stream, StreamExt as _};
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::sync::mpsc;
7
8/// Accepts connections on a transport
9pub trait Accept {
10    /// Transport-specific invocation context
11    type Context: Send + Sync + 'static;
12
13    /// Outgoing byte stream
14    type Outgoing: AsyncWrite + Send + Sync + Unpin + 'static;
15
16    /// Incoming byte stream
17    type Incoming: AsyncRead + Send + Sync + Unpin + 'static;
18
19    /// Accept a connection returning a pair of streams and connection context
20    fn accept(
21        &self,
22    ) -> impl Future<Output = std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>;
23}
24
/// Adapter produced by [`AcceptExt::map_context`].
///
/// Pairs a wrapped acceptor with the function that is applied to the context
/// of every connection the wrapped acceptor yields.
pub struct AcceptMapContext<T, F> {
    /// The wrapped acceptor
    inner: T,
    /// Context-mapping function
    f: F,
}
30
31impl<T, F> Deref for AcceptMapContext<T, F> {
32    type Target = T;
33
34    fn deref(&self) -> &Self::Target {
35        &self.inner
36    }
37}
38
39impl<T, F> DerefMut for AcceptMapContext<T, F> {
40    fn deref_mut(&mut self) -> &mut Self::Target {
41        &mut self.inner
42    }
43}
44
45/// Extension trait for [Accept]
46pub trait AcceptExt: Accept + Sized {
47    /// Maps [`Self::Context`](Accept::Context) to a type `T` using `F`
48    fn map_context<T, F: Fn(Self::Context) -> T>(self, f: F) -> AcceptMapContext<Self, F> {
49        AcceptMapContext { inner: self, f }
50    }
51}
52
53impl<T: Accept> AcceptExt for T {}
54
55impl<T, U, F> Accept for AcceptMapContext<T, F>
56where
57    T: Accept,
58    U: Send + Sync + 'static,
59    F: Fn(T::Context) -> U,
60{
61    type Context = U;
62    type Outgoing = T::Outgoing;
63    type Incoming = T::Incoming;
64
65    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
66        (&self).accept().await
67    }
68}
69
70impl<T, U, F> Accept for &AcceptMapContext<T, F>
71where
72    T: Accept,
73    U: Send + Sync + 'static,
74    F: Fn(T::Context) -> U,
75{
76    type Context = U;
77    type Outgoing = T::Outgoing;
78    type Incoming = T::Incoming;
79
80    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
81        let (cx, tx, rx) = self.inner.accept().await?;
82        Ok(((self.f)(cx), tx, rx))
83    }
84}
85
86/// A wrapper around a [Stream] of connections
87pub struct AcceptStream<T>(tokio::sync::Mutex<T>);
88
89impl<T> From<T> for AcceptStream<T> {
90    fn from(stream: T) -> Self {
91        Self(tokio::sync::Mutex::new(stream))
92    }
93}
94
95impl<T, C, O, I> Accept for AcceptStream<T>
96where
97    T: Stream<Item = (C, O, I)> + Unpin,
98    C: Send + Sync + 'static,
99    O: AsyncWrite + Send + Sync + Unpin + 'static,
100    I: AsyncRead + Send + Sync + Unpin + 'static,
101{
102    type Context = C;
103    type Outgoing = O;
104    type Incoming = I;
105
106    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
107        (&self).accept().await
108    }
109}
110
111impl<T, C, O, I> Accept for &AcceptStream<T>
112where
113    T: Stream<Item = (C, O, I)> + Unpin,
114    C: Send + Sync + 'static,
115    O: AsyncWrite + Send + Sync + Unpin + 'static,
116    I: AsyncRead + Send + Sync + Unpin + 'static,
117{
118    type Context = C;
119    type Outgoing = O;
120    type Incoming = I;
121
122    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
123        let mut stream = self.0.lock().await;
124        let Some((cx, tx, rx)) = stream.next().await else {
125            return Err(std::io::ErrorKind::UnexpectedEof.into());
126        };
127        Ok((cx, tx, rx))
128    }
129}
130
131/// A wrapper around an [mpsc::Receiver] of connections
132pub struct AcceptReceiver<C, O, I>(tokio::sync::Mutex<mpsc::Receiver<(C, O, I)>>);
133
134impl<C, O, I> From<mpsc::Receiver<(C, O, I)>> for AcceptReceiver<C, O, I> {
135    fn from(stream: mpsc::Receiver<(C, O, I)>) -> Self {
136        Self(tokio::sync::Mutex::new(stream))
137    }
138}
139
140impl<C, O, I> Accept for AcceptReceiver<C, O, I>
141where
142    C: Send + Sync + 'static,
143    O: AsyncWrite + Send + Sync + Unpin + 'static,
144    I: AsyncRead + Send + Sync + Unpin + 'static,
145{
146    type Context = C;
147    type Outgoing = O;
148    type Incoming = I;
149
150    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
151        (&self).accept().await
152    }
153}
154
155impl<C, O, I> Accept for &AcceptReceiver<C, O, I>
156where
157    C: Send + Sync + 'static,
158    O: AsyncWrite + Send + Sync + Unpin + 'static,
159    I: AsyncRead + Send + Sync + Unpin + 'static,
160{
161    type Context = C;
162    type Outgoing = O;
163    type Incoming = I;
164
165    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
166        let mut stream = self.0.lock().await;
167        let Some((cx, tx, rx)) = stream.recv().await else {
168            return Err(std::io::ErrorKind::UnexpectedEof.into());
169        };
170        Ok((cx, tx, rx))
171    }
172}
173
174/// A wrapper around an [mpsc::UnboundedReceiver] of connections
175pub struct AcceptUnboundedReceiver<C, O, I>(tokio::sync::Mutex<mpsc::UnboundedReceiver<(C, O, I)>>);
176
177impl<C, O, I> From<mpsc::UnboundedReceiver<(C, O, I)>> for AcceptUnboundedReceiver<C, O, I> {
178    fn from(stream: mpsc::UnboundedReceiver<(C, O, I)>) -> Self {
179        Self(tokio::sync::Mutex::new(stream))
180    }
181}
182
183impl<C, O, I> Accept for AcceptUnboundedReceiver<C, O, I>
184where
185    C: Send + Sync + 'static,
186    O: AsyncWrite + Send + Sync + Unpin + 'static,
187    I: AsyncRead + Send + Sync + Unpin + 'static,
188{
189    type Context = C;
190    type Outgoing = O;
191    type Incoming = I;
192
193    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
194        (&self).accept().await
195    }
196}
197
198impl<C, O, I> Accept for &AcceptUnboundedReceiver<C, O, I>
199where
200    C: Send + Sync + 'static,
201    O: AsyncWrite + Send + Sync + Unpin + 'static,
202    I: AsyncRead + Send + Sync + Unpin + 'static,
203{
204    type Context = C;
205    type Outgoing = O;
206    type Incoming = I;
207
208    async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
209        let mut stream = self.0.lock().await;
210        let Some((cx, tx, rx)) = stream.recv().await else {
211            return Err(std::io::ErrorKind::UnexpectedEof.into());
212        };
213        Ok((cx, tx, rx))
214    }
215}