// wtransport/driver/streams/mod.rs

1use crate::driver::utils::streamid_q2w;
2use crate::driver::utils::varint_q2w;
3use crate::driver::utils::varint_w2q;
4use crate::error::ClosedStream;
5use crate::error::StreamReadError;
6use crate::error::StreamReadExactError;
7use crate::error::StreamWriteError;
8use crate::SessionId;
9use crate::StreamId;
10use crate::VarInt;
11use std::pin::Pin;
12use std::task::ready;
13use std::task::Context;
14use std::task::Poll;
15use tokio::io::ReadBuf;
16use wtransport_proto::frame::Frame;
17use wtransport_proto::session::SessionRequest;
18use wtransport_proto::stream as stream_proto;
19use wtransport_proto::stream::Stream as StreamProto;
20use wtransport_proto::stream_header::StreamHeader;
21use wtransport_proto::stream_header::StreamKind;
22
23pub type ProtoReadError = wtransport_proto::stream::IoReadError;
24pub type ProtoWriteError = wtransport_proto::stream::IoWriteError;
25
/// Error returned when a receive stream cannot be stopped.
///
/// Produced by `QuicRecvStream::stop` whenever quinn rejects the stop request
/// (the underlying quinn error detail is discarded). It is a plain marker
/// value, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlreadyStop;
28
29#[derive(Debug)]
30pub struct QuicSendStream(quinn::SendStream);
31
32impl QuicSendStream {
33    #[inline(always)]
34    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, StreamWriteError> {
35        let written = self.0.write(buf).await?;
36        Ok(written)
37    }
38
39    #[inline(always)]
40    pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), StreamWriteError> {
41        self.0.write_all(buf).await?;
42        Ok(())
43    }
44
45    #[inline(always)]
46    pub async fn finish(&mut self) -> Result<(), StreamWriteError> {
47        let _ = self.0.finish();
48        let result = self.stopped().await;
49        if matches!(result, StreamWriteError::Closed) {
50            Ok(())
51        } else {
52            Err(result)
53        }
54    }
55
56    #[inline(always)]
57    pub fn set_priority(&self, priority: i32) {
58        let _ = self.0.set_priority(priority);
59    }
60
61    #[inline(always)]
62    /// # Panics
63    ///
64    /// If `reset` was called.
65    pub fn priority(&self) -> i32 {
66        self.0.priority().expect("Stream has been reset")
67    }
68
69    pub async fn stopped(&mut self) -> StreamWriteError {
70        match self.0.stopped().await {
71            Ok(None) => StreamWriteError::Closed,
72            Ok(Some(code)) => StreamWriteError::Stopped(varint_q2w(code)),
73            Err(quinn::StoppedError::ConnectionLost(_)) => StreamWriteError::NotConnected,
74            Err(quinn::StoppedError::ZeroRttRejected) => StreamWriteError::QuicProto,
75        }
76    }
77
78    #[inline(always)]
79    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
80        self.0
81            .reset(varint_w2q(error_code))
82            .map_err(|_| ClosedStream)
83    }
84
85    #[inline(always)]
86    pub fn id(&self) -> StreamId {
87        streamid_q2w(self.0.id())
88    }
89
90    #[cfg(feature = "quinn")]
91    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
92    #[inline(always)]
93    pub fn quic_stream(&self) -> &quinn::SendStream {
94        &self.0
95    }
96
97    #[cfg(feature = "quinn")]
98    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
99    #[inline(always)]
100    pub fn quic_stream_mut(&mut self) -> &mut quinn::SendStream {
101        &mut self.0
102    }
103}
104
105impl wtransport_proto::bytes::AsyncWrite for QuicSendStream {
106    #[inline(always)]
107    fn poll_write(
108        mut self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        buf: &[u8],
111    ) -> Poll<std::io::Result<usize>> {
112        tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
113    }
114}
115
116impl tokio::io::AsyncWrite for QuicSendStream {
117    #[inline(always)]
118    fn poll_write(
119        mut self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &[u8],
122    ) -> Poll<Result<usize, std::io::Error>> {
123        tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
124    }
125
126    #[inline(always)]
127    fn poll_flush(
128        mut self: Pin<&mut Self>,
129        cx: &mut Context<'_>,
130    ) -> Poll<Result<(), std::io::Error>> {
131        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
132    }
133
134    #[inline(always)]
135    fn poll_shutdown(
136        mut self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138    ) -> Poll<Result<(), std::io::Error>> {
139        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
140    }
141
142    #[inline(always)]
143    fn poll_write_vectored(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146        bufs: &[std::io::IoSlice<'_>],
147    ) -> Poll<Result<usize, std::io::Error>> {
148        tokio::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs)
149    }
150
151    fn is_write_vectored(&self) -> bool {
152        tokio::io::AsyncWrite::is_write_vectored(&self.0)
153    }
154}
155
156#[derive(Debug)]
157pub struct QuicRecvStream(quinn::RecvStream);
158
159impl QuicRecvStream {
160    #[inline(always)]
161    pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, StreamReadError> {
162        match self.0.read(buf).await? {
163            Some(read) => Ok(Some(read)),
164            None => Ok(None),
165        }
166    }
167
168    #[inline(always)]
169    pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), StreamReadExactError> {
170        self.0
171            .read_exact(buf)
172            .await
173            .map_err(|quic_error| match quic_error {
174                quinn::ReadExactError::FinishedEarly(read) => {
175                    StreamReadExactError::FinishedEarly(read)
176                }
177                quinn::ReadExactError::ReadError(read) => StreamReadExactError::Read(read.into()),
178            })
179    }
180
181    #[inline(always)]
182    pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
183        self.0.stop(varint_w2q(error_code)).map_err(|_| AlreadyStop)
184    }
185
186    #[inline(always)]
187    pub fn id(&self) -> StreamId {
188        streamid_q2w(self.0.id())
189    }
190
191    #[cfg(feature = "quinn")]
192    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
193    #[inline(always)]
194    pub fn quic_stream(&self) -> &quinn::RecvStream {
195        &self.0
196    }
197
198    #[cfg(feature = "quinn")]
199    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
200    #[inline(always)]
201    pub fn quic_stream_mut(&mut self) -> &mut quinn::RecvStream {
202        &mut self.0
203    }
204}
205
206impl wtransport_proto::bytes::AsyncRead for QuicRecvStream {
207    #[inline(always)]
208    fn poll_read(
209        mut self: Pin<&mut Self>,
210        cx: &mut Context<'_>,
211        buf: &mut [u8],
212    ) -> Poll<std::io::Result<usize>> {
213        let mut buffer = ReadBuf::new(buf);
214
215        match ready!(tokio::io::AsyncRead::poll_read(
216            Pin::new(&mut self.0),
217            cx,
218            &mut buffer
219        )) {
220            Ok(()) => Poll::Ready(Ok(buffer.filled().len())),
221            Err(io_error) => Poll::Ready(Err(io_error)),
222        }
223    }
224}
225
226impl tokio::io::AsyncRead for QuicRecvStream {
227    #[inline(always)]
228    fn poll_read(
229        mut self: Pin<&mut Self>,
230        cx: &mut Context<'_>,
231        buf: &mut ReadBuf<'_>,
232    ) -> Poll<std::io::Result<()>> {
233        tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
234    }
235}
236
/// A raw QUIC stream handle paired with its protocol state.
///
/// The submodules of this file instantiate `S` with a send half, a receive
/// half, or both. They instantiate `P` with the matching protocol stage type
/// (QUIC, HTTP/3 or WebTransport).
#[derive(Debug)]
pub struct Stream<S, P> {
    /// Underlying QUIC stream half (or `(send, recv)` pair).
    stream: S,
    /// Protocol state for the current stage of the stream.
    proto: P,
}
242
243pub mod biremote {
244    use super::*;
245
246    pub type StreamBiRemoteQuic =
247        Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteQuic>;
248
249    pub type StreamBiRemoteH3 =
250        Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteH3>;
251
252    pub type StreamBiRemoteWT =
253        Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteWT>;
254
255    impl StreamBiRemoteQuic {
256        pub async fn accept_bi(quic_connection: &quinn::Connection) -> Option<Self> {
257            let stream = quic_connection.accept_bi().await.ok()?;
258            Some(Self {
259                stream: (QuicSendStream(stream.0), QuicRecvStream(stream.1)),
260                proto: StreamProto::accept_bi(),
261            })
262        }
263
264        pub fn upgrade(self) -> StreamBiRemoteH3 {
265            StreamBiRemoteH3 {
266                stream: self.stream,
267                proto: self.proto.upgrade(),
268            }
269        }
270
271        #[inline(always)]
272        pub fn id(&self) -> StreamId {
273            self.stream.0.id()
274        }
275    }
276
277    impl StreamBiRemoteH3 {
278        pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
279            self.proto.read_frame_async(&mut self.stream.1).await
280        }
281
282        pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
283            self.stream.1.stop(error_code)
284        }
285
286        pub fn upgrade(self, session_id: SessionId) -> StreamBiRemoteWT {
287            StreamBiRemoteWT {
288                stream: self.stream,
289                proto: self.proto.upgrade(session_id),
290            }
291        }
292
293        pub fn id(&self) -> StreamId {
294            self.stream.0.id()
295        }
296
297        pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
298            session::StreamSession {
299                stream: self.stream,
300                proto: self.proto.into_session(session_request),
301            }
302        }
303    }
304
305    impl StreamBiRemoteWT {
306        #[inline(always)]
307        pub fn session_id(&self) -> SessionId {
308            self.proto.session_id()
309        }
310
311        #[inline(always)]
312        pub fn id(&self) -> StreamId {
313            self.stream.0.id()
314        }
315
316        #[inline(always)]
317        pub fn into_stream(self) -> (QuicSendStream, QuicRecvStream) {
318            self.stream
319        }
320    }
321}
322
323pub mod bilocal {
324    use super::*;
325
326    pub type StreamBiLocalQuic =
327        Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalQuic>;
328
329    pub type StreamBiLocalH3 =
330        Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalH3>;
331
332    pub type StreamBiLocalWT =
333        Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalWT>;
334
335    impl StreamBiLocalQuic {
336        pub async fn open_bi(quic_connection: &quinn::Connection) -> Option<Self> {
337            let stream = quic_connection.open_bi().await.ok()?;
338            Some(Self {
339                stream: (QuicSendStream(stream.0), QuicRecvStream(stream.1)),
340                proto: StreamProto::open_bi(),
341            })
342        }
343
344        pub fn upgrade(self) -> StreamBiLocalH3 {
345            StreamBiLocalH3 {
346                stream: self.stream,
347                proto: self.proto.upgrade(),
348            }
349        }
350    }
351
352    impl StreamBiLocalH3 {
353        pub async fn upgrade(
354            mut self,
355            session_id: SessionId,
356        ) -> Result<StreamBiLocalWT, ProtoWriteError> {
357            let proto = self
358                .proto
359                .upgrade_async(session_id, &mut self.stream.0)
360                .await?;
361
362            Ok(StreamBiLocalWT {
363                stream: self.stream,
364                proto,
365            })
366        }
367
368        pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
369            session::StreamSession {
370                stream: self.stream,
371                proto: self.proto.into_session(session_request),
372            }
373        }
374    }
375
376    impl StreamBiLocalWT {
377        pub fn into_stream(self) -> (QuicSendStream, QuicRecvStream) {
378            self.stream
379        }
380    }
381}
382
383pub mod uniremote {
384    use super::*;
385
386    pub type StreamUniRemoteQuic =
387        Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteQuic>;
388
389    pub type StreamUniRemoteH3 = Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteH3>;
390
391    pub type StreamUniRemoteWT = Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteWT>;
392
393    impl StreamUniRemoteQuic {
394        pub async fn accept_uni(quic_connection: &quinn::Connection) -> Option<Self> {
395            let stream = quic_connection.accept_uni().await.ok()?;
396            Some(Self {
397                stream: QuicRecvStream(stream),
398                proto: StreamProto::accept_uni(),
399            })
400        }
401
402        pub async fn upgrade(mut self) -> Result<StreamUniRemoteH3, ProtoReadError> {
403            let proto = self.proto.upgrade_async(&mut self.stream).await?;
404            Ok(StreamUniRemoteH3 {
405                stream: self.stream,
406                proto,
407            })
408        }
409
410        #[inline(always)]
411        pub fn id(&self) -> StreamId {
412            self.stream.id()
413        }
414    }
415
416    impl StreamUniRemoteH3 {
417        pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
418            self.proto.read_frame_async(&mut self.stream).await
419        }
420
421        pub fn kind(&self) -> StreamKind {
422            self.proto.kind()
423        }
424
425        pub fn upgrade(self) -> StreamUniRemoteWT {
426            StreamUniRemoteWT {
427                stream: self.stream,
428                proto: self.proto.upgrade(),
429            }
430        }
431
432        pub fn stream_mut(&mut self) -> &mut QuicRecvStream {
433            &mut self.stream
434        }
435    }
436
437    impl StreamUniRemoteWT {
438        #[inline(always)]
439        pub fn session_id(&self) -> SessionId {
440            self.proto.session_id()
441        }
442
443        #[inline(always)]
444        pub fn id(&self) -> StreamId {
445            self.stream.id()
446        }
447
448        #[inline(always)]
449        pub fn into_stream(self) -> QuicRecvStream {
450            self.stream
451        }
452    }
453}
454
455pub mod unilocal {
456    use super::*;
457
458    pub type StreamUniLocalQuic =
459        Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalQuic>;
460
461    pub type StreamUniLocalH3 = Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalH3>;
462
463    pub type StreamUniLocalWT = Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalWT>;
464
465    impl StreamUniLocalQuic {
466        pub async fn open_uni(quic_connection: &quinn::Connection) -> Option<Self> {
467            let stream = quic_connection.open_uni().await.ok()?;
468            Some(Self {
469                stream: QuicSendStream(stream),
470                proto: StreamProto::open_uni(),
471            })
472        }
473
474        pub async fn upgrade(
475            mut self,
476            stream_header: StreamHeader,
477        ) -> Result<StreamUniLocalH3, ProtoWriteError> {
478            let proto = self
479                .proto
480                .upgrade_async(stream_header, &mut self.stream)
481                .await?;
482
483            Ok(StreamUniLocalH3 {
484                stream: self.stream,
485                proto,
486            })
487        }
488    }
489
490    impl StreamUniLocalH3 {
491        pub async fn write_frame(&mut self, frame: Frame<'_>) -> Result<(), ProtoWriteError> {
492            self.proto.write_frame_async(frame, &mut self.stream).await
493        }
494
495        pub fn kind(&self) -> StreamKind {
496            self.proto.kind()
497        }
498
499        pub async fn stopped(&mut self) -> StreamWriteError {
500            self.stream.stopped().await
501        }
502
503        pub fn upgrade(self) -> StreamUniLocalWT {
504            StreamUniLocalWT {
505                stream: self.stream,
506                proto: self.proto.upgrade(),
507            }
508        }
509    }
510
511    impl StreamUniLocalWT {
512        pub fn into_stream(self) -> QuicSendStream {
513            self.stream
514        }
515    }
516}
517
518pub mod session {
519    use super::*;
520
521    pub type StreamSession =
522        Stream<(QuicSendStream, QuicRecvStream), stream_proto::session::StreamSession>;
523
524    impl StreamSession {
525        pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
526            self.proto.read_frame_async(&mut self.stream.1).await
527        }
528
529        pub async fn write_frame(&mut self, frame: Frame<'_>) -> Result<(), ProtoWriteError> {
530            self.proto
531                .write_frame_async(frame, &mut self.stream.0)
532                .await
533        }
534
535        pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
536            self.stream.1.stop(error_code)
537        }
538
539        pub fn id(&self) -> StreamId {
540            self.stream.0.id()
541        }
542
543        pub fn session_id(&self) -> SessionId {
544            SessionId::try_from_session_stream(self.id()).expect("Session stream must be valid")
545        }
546
547        pub fn request(&self) -> &SessionRequest {
548            self.proto.request()
549        }
550
551        pub async fn finish(&mut self) {
552            let _ = self.stream.0.finish().await;
553        }
554
555        pub fn reset(&mut self, error_code: VarInt) {
556            let _ = self.stream.0.reset(error_code);
557        }
558    }
559}
560
561impl From<quinn::WriteError> for StreamWriteError {
562    fn from(error: quinn::WriteError) -> Self {
563        match error {
564            quinn::WriteError::Stopped(code) => StreamWriteError::Stopped(varint_q2w(code)),
565            quinn::WriteError::ConnectionLost(_) | quinn::WriteError::ClosedStream => {
566                StreamWriteError::NotConnected
567            }
568            quinn::WriteError::ZeroRttRejected => StreamWriteError::QuicProto,
569        }
570    }
571}
572
573impl From<quinn::ReadError> for StreamReadError {
574    fn from(error: quinn::ReadError) -> Self {
575        match error {
576            quinn::ReadError::Reset(code) => StreamReadError::Reset(varint_q2w(code)),
577            quinn::ReadError::ConnectionLost(_) | quinn::ReadError::ClosedStream => {
578                StreamReadError::NotConnected
579            }
580            quinn::ReadError::IllegalOrderedRead => StreamReadError::QuicProto,
581            quinn::ReadError::ZeroRttRejected => StreamReadError::QuicProto,
582        }
583    }
584}
585
586pub mod connect;
587pub mod qpack;
588pub mod settings;