1use crate::driver::streams::bilocal::StreamBiLocalQuic;
2use crate::driver::streams::unilocal::StreamUniLocalQuic;
3use crate::driver::streams::ProtoWriteError;
4use crate::driver::streams::QuicRecvStream;
5use crate::driver::streams::QuicSendStream;
6use crate::error::ClosedStream;
7use crate::error::StreamOpeningError;
8use crate::error::StreamReadError;
9use crate::error::StreamReadExactError;
10use crate::error::StreamWriteError;
11use crate::SessionId;
12use crate::StreamId;
13use crate::VarInt;
14use std::future::Future;
15use std::pin::Pin;
16use std::task::Context;
17use std::task::Poll;
18use tokio::io::ReadBuf;
19use wtransport_proto::stream_header::StreamHeader;
20
21#[derive(Debug)]
23pub struct SendStream(QuicSendStream);
24
25impl SendStream {
26 #[inline(always)]
27 pub(crate) fn new(stream: QuicSendStream) -> Self {
28 Self(stream)
29 }
30
31 #[inline(always)]
37 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, StreamWriteError> {
38 self.0.write(buf).await
39 }
40
41 #[inline(always)]
43 pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), StreamWriteError> {
44 self.0.write_all(buf).await
45 }
46
47 #[inline(always)]
52 pub async fn finish(&mut self) -> Result<(), StreamWriteError> {
53 self.0.finish().await
54 }
55
56 #[inline(always)]
58 pub fn id(&self) -> StreamId {
59 self.0.id()
60 }
61
62 #[inline(always)]
70 pub fn set_priority(&self, priority: i32) {
71 self.0.set_priority(priority);
72 }
73
74 #[inline(always)]
80 pub fn priority(&self) -> i32 {
81 self.0.priority()
82 }
83
84 #[inline(always)]
92 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
93 self.0.reset(error_code)
94 }
95
96 #[inline(always)]
102 pub async fn stopped(&mut self) -> StreamWriteError {
103 self.0.stopped().await
104 }
105
106 #[cfg(feature = "quinn")]
108 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
109 #[inline(always)]
110 pub fn quic_stream(&self) -> &quinn::SendStream {
111 self.0.quic_stream()
112 }
113
114 #[cfg(feature = "quinn")]
116 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
117 #[inline(always)]
118 pub fn quic_stream_mut(&mut self) -> &mut quinn::SendStream {
119 self.0.quic_stream_mut()
120 }
121}
122
123#[derive(Debug)]
125pub struct RecvStream(QuicRecvStream);
126
127impl RecvStream {
128 #[inline(always)]
129 pub(crate) fn new(stream: QuicRecvStream) -> Self {
130 Self(stream)
131 }
132
133 #[inline(always)]
137 pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, StreamReadError> {
138 self.0.read(buf).await
139 }
140
141 pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), StreamReadExactError> {
146 self.0.read_exact(buf).await
147 }
148
149 pub fn stop(mut self, error_code: VarInt) {
153 let _ = self.0.stop(error_code);
154 }
155
156 #[inline(always)]
158 pub fn id(&self) -> StreamId {
159 self.0.id()
160 }
161
162 #[cfg(feature = "quinn")]
164 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
165 #[inline(always)]
166 pub fn quic_stream(&self) -> &quinn::RecvStream {
167 self.0.quic_stream()
168 }
169
170 #[cfg(feature = "quinn")]
172 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
173 #[inline(always)]
174 pub fn quic_stream_mut(&mut self) -> &mut quinn::RecvStream {
175 self.0.quic_stream_mut()
176 }
177}
178
179#[derive(Debug)]
206pub struct BiStream((SendStream, RecvStream));
207
208impl BiStream {
209 pub fn join(s: (SendStream, RecvStream)) -> Self {
211 Self(s)
212 }
213
214 pub fn split(self) -> (SendStream, RecvStream) {
216 self.0
217 }
218
219 pub fn send(&self) -> &SendStream {
221 &self.0 .0
222 }
223
224 pub fn send_mut(&mut self) -> &mut SendStream {
226 &mut self.0 .0
227 }
228
229 pub fn recv(&self) -> &RecvStream {
231 &self.0 .1
232 }
233
234 pub fn recv_mut(&mut self) -> &mut RecvStream {
236 &mut self.0 .1
237 }
238}
239
240impl From<(SendStream, RecvStream)> for BiStream {
241 fn from(value: (SendStream, RecvStream)) -> Self {
242 Self::join(value)
243 }
244}
245
246impl From<BiStream> for (SendStream, RecvStream) {
247 fn from(value: BiStream) -> Self {
248 value.split()
249 }
250}
251
252impl tokio::io::AsyncWrite for SendStream {
253 #[inline(always)]
254 fn poll_write(
255 mut self: Pin<&mut Self>,
256 cx: &mut Context<'_>,
257 buf: &[u8],
258 ) -> Poll<std::io::Result<usize>> {
259 tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
260 }
261
262 #[inline(always)]
263 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
264 tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
265 }
266
267 #[inline(always)]
268 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
269 tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
270 }
271
272 #[inline(always)]
273 fn poll_write_vectored(
274 mut self: Pin<&mut Self>,
275 cx: &mut Context<'_>,
276 bufs: &[std::io::IoSlice<'_>],
277 ) -> Poll<Result<usize, std::io::Error>> {
278 tokio::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs)
279 }
280
281 #[inline(always)]
282 fn is_write_vectored(&self) -> bool {
283 tokio::io::AsyncWrite::is_write_vectored(&self.0)
284 }
285}
286
287impl tokio::io::AsyncRead for RecvStream {
288 #[inline(always)]
289 fn poll_read(
290 mut self: Pin<&mut Self>,
291 cx: &mut Context<'_>,
292 buf: &mut ReadBuf<'_>,
293 ) -> Poll<std::io::Result<()>> {
294 tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
295 }
296}
297
298impl tokio::io::AsyncWrite for BiStream {
299 fn poll_write(
300 mut self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 buf: &[u8],
303 ) -> Poll<Result<usize, std::io::Error>> {
304 tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0 .0), cx, buf)
305 }
306
307 fn poll_flush(
308 mut self: Pin<&mut Self>,
309 cx: &mut Context<'_>,
310 ) -> Poll<Result<(), std::io::Error>> {
311 tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0 .0), cx)
312 }
313
314 fn poll_shutdown(
315 mut self: Pin<&mut Self>,
316 cx: &mut Context<'_>,
317 ) -> Poll<Result<(), std::io::Error>> {
318 tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0 .0), cx)
319 }
320}
321
322impl tokio::io::AsyncRead for BiStream {
323 fn poll_read(
324 mut self: Pin<&mut Self>,
325 cx: &mut Context<'_>,
326 buf: &mut ReadBuf<'_>,
327 ) -> Poll<std::io::Result<()>> {
328 tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0 .1), cx, buf)
329 }
330}
331
332type DynFutureUniStream = dyn Future<Output = Result<SendStream, StreamOpeningError>> + Send + Sync;
333
334pub struct OpeningUniStream(Pin<Box<DynFutureUniStream>>);
338
339impl OpeningUniStream {
340 pub(crate) fn new(session_id: SessionId, quic_stream: StreamUniLocalQuic) -> Self {
341 Self(Box::pin(async move {
342 match quic_stream
343 .upgrade(StreamHeader::new_webtransport(session_id))
344 .await
345 {
346 Ok(stream) => Ok(SendStream(stream.upgrade().into_stream())),
347 Err(ProtoWriteError::NotConnected) => Err(StreamOpeningError::NotConnected),
348 Err(ProtoWriteError::Stopped) => Err(StreamOpeningError::Refused),
349 }
350 }))
351 }
352}
353
354impl Future for OpeningUniStream {
355 type Output = Result<SendStream, StreamOpeningError>;
356
357 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358 Future::poll(self.0.as_mut(), cx)
359 }
360}
361
362type DynFutureBiStream =
363 dyn Future<Output = Result<(SendStream, RecvStream), StreamOpeningError>> + Send + Sync;
364
365pub struct OpeningBiStream(Pin<Box<DynFutureBiStream>>);
369
370impl OpeningBiStream {
371 pub(crate) fn new(session_id: SessionId, quic_stream: StreamBiLocalQuic) -> Self {
372 Self(Box::pin(async move {
373 match quic_stream.upgrade().upgrade(session_id).await {
374 Ok(stream) => {
375 let stream = stream.into_stream();
376 Ok((SendStream::new(stream.0), RecvStream::new(stream.1)))
377 }
378 Err(ProtoWriteError::NotConnected) => Err(StreamOpeningError::NotConnected),
379 Err(ProtoWriteError::Stopped) => Err(StreamOpeningError::Refused),
380 }
381 }))
382 }
383}
384
385impl Future for OpeningBiStream {
386 type Output = Result<(SendStream, RecvStream), StreamOpeningError>;
387
388 #[inline(always)]
389 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390 Future::poll(self.0.as_mut(), cx)
391 }
392}