wrpc_transport/frame/conn/
accept.rs1use core::future::Future;
2use core::ops::{Deref, DerefMut};
3
4use futures::{Stream, StreamExt as _};
5use tokio::io::{AsyncRead, AsyncWrite};
6use tokio::sync::mpsc;
7
8pub trait Accept {
10 type Context: Send + Sync + 'static;
12
13 type Outgoing: AsyncWrite + Send + Sync + Unpin + 'static;
15
16 type Incoming: AsyncRead + Send + Sync + Unpin + 'static;
18
19 fn accept(
21 &self,
22 ) -> impl Future<Output = std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>;
23}
24
/// An acceptor adapter that passes every accepted connection's context through a function.
///
/// Built by [`AcceptExt::map_context`]. It dereferences (mutably and immutably) to the wrapped
/// acceptor, so the wrapped value's own methods stay reachable through the adapter.
pub struct AcceptMapContext<T, F> {
    /// The wrapped acceptor.
    inner: T,
    /// Mapping applied to each context produced by `inner`.
    f: F,
}

impl<T, F> Deref for AcceptMapContext<T, F> {
    type Target = T;

    /// Borrows the wrapped acceptor.
    fn deref(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }
}

impl<T, F> DerefMut for AcceptMapContext<T, F> {
    /// Mutably borrows the wrapped acceptor.
    fn deref_mut(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }
}
44
45pub trait AcceptExt: Accept + Sized {
47 fn map_context<T, F: Fn(Self::Context) -> T>(self, f: F) -> AcceptMapContext<Self, F> {
49 AcceptMapContext { inner: self, f }
50 }
51}
52
53impl<T: Accept> AcceptExt for T {}
54
55impl<T, U, F> Accept for AcceptMapContext<T, F>
56where
57 T: Accept,
58 U: Send + Sync + 'static,
59 F: Fn(T::Context) -> U,
60{
61 type Context = U;
62 type Outgoing = T::Outgoing;
63 type Incoming = T::Incoming;
64
65 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
66 (&self).accept().await
67 }
68}
69
70impl<T, U, F> Accept for &AcceptMapContext<T, F>
71where
72 T: Accept,
73 U: Send + Sync + 'static,
74 F: Fn(T::Context) -> U,
75{
76 type Context = U;
77 type Outgoing = T::Outgoing;
78 type Incoming = T::Incoming;
79
80 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
81 let (cx, tx, rx) = self.inner.accept().await?;
82 Ok(((self.f)(cx), tx, rx))
83 }
84}
85
86pub struct AcceptStream<T>(tokio::sync::Mutex<T>);
88
89impl<T> From<T> for AcceptStream<T> {
90 fn from(stream: T) -> Self {
91 Self(tokio::sync::Mutex::new(stream))
92 }
93}
94
95impl<T, C, O, I> Accept for AcceptStream<T>
96where
97 T: Stream<Item = (C, O, I)> + Unpin,
98 C: Send + Sync + 'static,
99 O: AsyncWrite + Send + Sync + Unpin + 'static,
100 I: AsyncRead + Send + Sync + Unpin + 'static,
101{
102 type Context = C;
103 type Outgoing = O;
104 type Incoming = I;
105
106 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
107 (&self).accept().await
108 }
109}
110
111impl<T, C, O, I> Accept for &AcceptStream<T>
112where
113 T: Stream<Item = (C, O, I)> + Unpin,
114 C: Send + Sync + 'static,
115 O: AsyncWrite + Send + Sync + Unpin + 'static,
116 I: AsyncRead + Send + Sync + Unpin + 'static,
117{
118 type Context = C;
119 type Outgoing = O;
120 type Incoming = I;
121
122 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
123 let mut stream = self.0.lock().await;
124 let Some((cx, tx, rx)) = stream.next().await else {
125 return Err(std::io::ErrorKind::UnexpectedEof.into());
126 };
127 Ok((cx, tx, rx))
128 }
129}
130
131pub struct AcceptReceiver<C, O, I>(tokio::sync::Mutex<mpsc::Receiver<(C, O, I)>>);
133
134impl<C, O, I> From<mpsc::Receiver<(C, O, I)>> for AcceptReceiver<C, O, I> {
135 fn from(stream: mpsc::Receiver<(C, O, I)>) -> Self {
136 Self(tokio::sync::Mutex::new(stream))
137 }
138}
139
140impl<C, O, I> Accept for AcceptReceiver<C, O, I>
141where
142 C: Send + Sync + 'static,
143 O: AsyncWrite + Send + Sync + Unpin + 'static,
144 I: AsyncRead + Send + Sync + Unpin + 'static,
145{
146 type Context = C;
147 type Outgoing = O;
148 type Incoming = I;
149
150 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
151 (&self).accept().await
152 }
153}
154
155impl<C, O, I> Accept for &AcceptReceiver<C, O, I>
156where
157 C: Send + Sync + 'static,
158 O: AsyncWrite + Send + Sync + Unpin + 'static,
159 I: AsyncRead + Send + Sync + Unpin + 'static,
160{
161 type Context = C;
162 type Outgoing = O;
163 type Incoming = I;
164
165 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
166 let mut stream = self.0.lock().await;
167 let Some((cx, tx, rx)) = stream.recv().await else {
168 return Err(std::io::ErrorKind::UnexpectedEof.into());
169 };
170 Ok((cx, tx, rx))
171 }
172}
173
174pub struct AcceptUnboundedReceiver<C, O, I>(tokio::sync::Mutex<mpsc::UnboundedReceiver<(C, O, I)>>);
176
177impl<C, O, I> From<mpsc::UnboundedReceiver<(C, O, I)>> for AcceptUnboundedReceiver<C, O, I> {
178 fn from(stream: mpsc::UnboundedReceiver<(C, O, I)>) -> Self {
179 Self(tokio::sync::Mutex::new(stream))
180 }
181}
182
183impl<C, O, I> Accept for AcceptUnboundedReceiver<C, O, I>
184where
185 C: Send + Sync + 'static,
186 O: AsyncWrite + Send + Sync + Unpin + 'static,
187 I: AsyncRead + Send + Sync + Unpin + 'static,
188{
189 type Context = C;
190 type Outgoing = O;
191 type Incoming = I;
192
193 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
194 (&self).accept().await
195 }
196}
197
198impl<C, O, I> Accept for &AcceptUnboundedReceiver<C, O, I>
199where
200 C: Send + Sync + 'static,
201 O: AsyncWrite + Send + Sync + Unpin + 'static,
202 I: AsyncRead + Send + Sync + Unpin + 'static,
203{
204 type Context = C;
205 type Outgoing = O;
206 type Incoming = I;
207
208 async fn accept(&self) -> std::io::Result<(Self::Context, Self::Outgoing, Self::Incoming)> {
209 let mut stream = self.0.lock().await;
210 let Some((cx, tx, rx)) = stream.recv().await else {
211 return Err(std::io::ErrorKind::UnexpectedEof.into());
212 };
213 Ok((cx, tx, rx))
214 }
215}