wtransport_proto/
stream.rs

1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::error::ErrorCode;
7use crate::frame;
8use crate::frame::Frame;
9use crate::frame::FrameKind;
10use crate::ids::SessionId;
11use crate::session::SessionRequest;
12use crate::stream_header;
13use crate::stream_header::StreamHeader;
14use crate::stream_header::StreamKind;
15
16#[cfg(feature = "async")]
17use crate::bytes::AsyncRead;
18
19#[cfg(feature = "async")]
20use crate::bytes::AsyncWrite;
21
22#[cfg(feature = "async")]
23use crate::bytes;
24
/// Error returned by asynchronous stream read operations.
///
/// A read fails either because an HTTP3 protocol error was detected while
/// decoding ([`IoReadError::H3`]) or because the underlying byte-level read
/// failed ([`IoReadError::IO`]).
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug, thiserror::Error)]
pub enum IoReadError {
    /// HTTP3 protocol error, carrying the error code to report.
    #[error(transparent)]
    H3(ErrorCode),

    /// Failure of the underlying I/O read.
    #[error(transparent)]
    IO(bytes::IoReadError),
}
38
/// An error during stream I/O write operation.
///
/// This is an alias of [`bytes::IoWriteError`], the byte-level write error.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub type IoWriteError = bytes::IoWriteError;
42
/// A QUIC/HTTP3/WebTransport stream.
///
/// `K` encodes the stream kind (direction and initiator, e.g. `BiRemote`) and
/// `S` the protocol stage the stream has reached (e.g. `Quic`, `H3`, `WT`).
/// The `upgrade` methods consume a stream and return it in the next stage.
#[derive(Debug)]
pub struct Stream<K, S> {
    /// Kind of the stream (bidirectional/unidirectional, local/remote).
    kind: K,

    /// Current protocol stage of the stream.
    stage: S,
}
49
50/// Bidirectional remote stream implementations.
51pub mod biremote {
52    use super::*;
53    use types::*;
54
55    /// QUIC bidirectional remote stream.
56    pub type StreamBiRemoteQuic = Stream<BiRemote, Quic>;
57
58    /// HTTP3 bidirectional remote stream.
59    pub type StreamBiRemoteH3 = Stream<BiRemote, H3>;
60
61    /// WebTransport bidirectional remote stream.
62    pub type StreamBiRemoteWT = Stream<BiRemote, WT>;
63
64    impl StreamBiRemoteQuic {
65        /// Creates a new remote-initialized bidirectional stream.
66        pub fn accept_bi() -> Self {
67            Self {
68                kind: BiRemote::default(),
69                stage: Quic,
70            }
71        }
72
73        /// Upgrades to an HTTP3 stream.
74        pub fn upgrade(self) -> StreamBiRemoteH3 {
75            StreamBiRemoteH3 {
76                kind: self.kind,
77                stage: H3::new(None),
78            }
79        }
80    }
81
82    impl StreamBiRemoteH3 {
83        /// See [`Frame::read`].
84        pub fn read_frame<'a, R>(
85            &mut self,
86            bytes_reader: &mut R,
87        ) -> Result<Option<Frame<'a>>, ErrorCode>
88        where
89            R: BytesReader<'a>,
90        {
91            loop {
92                match Frame::read(bytes_reader) {
93                    Ok(Some(frame)) => {
94                        return Ok(Some(self.validate_frame(frame)?));
95                    }
96                    Ok(None) => {
97                        return Ok(None);
98                    }
99                    Err(frame::ParseError::UnknownFrame) => {
100                        continue;
101                    }
102                    Err(frame::ParseError::InvalidSessionId) => {
103                        return Err(ErrorCode::Id);
104                    }
105                    Err(frame::ParseError::PayloadTooBig) => {
106                        return Err(ErrorCode::ExcessiveLoad);
107                    }
108                }
109            }
110        }
111
112        /// See [`Frame::read_async`].
113        #[cfg(feature = "async")]
114        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
115        pub async fn read_frame_async<'a, R>(
116            &mut self,
117            reader: &mut R,
118        ) -> Result<Frame<'a>, IoReadError>
119        where
120            R: AsyncRead + Unpin + ?Sized,
121        {
122            loop {
123                match Frame::read_async(reader).await {
124                    Ok(frame) => {
125                        return self.validate_frame(frame).map_err(IoReadError::H3);
126                    }
127                    Err(frame::IoReadError::Parse(frame::ParseError::UnknownFrame)) => {
128                        continue;
129                    }
130                    Err(frame::IoReadError::Parse(frame::ParseError::InvalidSessionId)) => {
131                        return Err(IoReadError::H3(ErrorCode::Id));
132                    }
133                    Err(frame::IoReadError::Parse(frame::ParseError::PayloadTooBig)) => {
134                        return Err(IoReadError::H3(ErrorCode::ExcessiveLoad));
135                    }
136                    Err(frame::IoReadError::IO(io_error)) => {
137                        if matches!(io_error, bytes::IoReadError::UnexpectedFin) {
138                            return Err(IoReadError::H3(ErrorCode::Frame));
139                        }
140
141                        return Err(IoReadError::IO(io_error));
142                    }
143                }
144            }
145        }
146
147        /// See [`Frame::read_from_buffer`].
148        pub fn read_frame_from_buffer<'a>(
149            &mut self,
150            buffer_reader: &mut BufferReader<'a>,
151        ) -> Result<Option<Frame<'a>>, ErrorCode> {
152            let mut buffer_reader_child = buffer_reader.child();
153
154            match self.read_frame(&mut *buffer_reader_child)? {
155                Some(frame) => {
156                    buffer_reader_child.commit();
157                    Ok(Some(frame))
158                }
159                None => Ok(None),
160            }
161        }
162
163        /// See [`Frame::write`].
164        pub fn write_frame<W>(&self, frame: Frame, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
165        where
166            W: BytesWriter,
167        {
168            frame.write(bytes_writer)
169        }
170
171        /// See [`Frame::write_async`].
172        #[cfg(feature = "async")]
173        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
174        pub async fn write_frame_async<W>(
175            &self,
176            frame: Frame<'_>,
177            writer: &mut W,
178        ) -> Result<(), IoWriteError>
179        where
180            W: AsyncWrite + Unpin + ?Sized,
181        {
182            frame.write_async(writer).await
183        }
184
185        /// See [`Frame::write_to_buffer`].
186        pub fn write_frame_to_buffer(
187            &self,
188            frame: Frame,
189            buffer_writer: &mut BufferWriter,
190        ) -> Result<(), EndOfBuffer> {
191            frame.write_to_buffer(buffer_writer)
192        }
193
194        /// Upgrades to a WebTransport stream.
195        ///
196        /// **Note**: upgrade should be performed only when [`FrameKind::WebTransport`] is
197        /// received as first frame on this HTTP3 stream.
198        pub fn upgrade(self, session_id: SessionId) -> StreamBiRemoteWT {
199            StreamBiRemoteWT {
200                kind: self.kind,
201                stage: WT::new(session_id),
202            }
203        }
204
205        /// Converts the stream into a `StreamSession`.
206        pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
207            session::StreamSession {
208                kind: Bi,
209                stage: Session::new(session_request),
210            }
211        }
212
213        fn validate_frame<'a>(&mut self, frame: Frame<'a>) -> Result<Frame<'a>, ErrorCode> {
214            let first_frame_done = self.stage.set_first_frame();
215
216            match frame.kind() {
217                FrameKind::Data => Ok(frame),
218                FrameKind::Headers => Ok(frame),
219                FrameKind::Settings => Err(ErrorCode::FrameUnexpected),
220                FrameKind::WebTransport => {
221                    if !first_frame_done {
222                        Ok(frame)
223                    } else {
224                        Err(ErrorCode::Frame)
225                    }
226                }
227                FrameKind::Exercise(_) => Ok(frame),
228            }
229        }
230    }
231
232    impl StreamBiRemoteWT {
233        /// Returns the [`SessionId`] associated with this stream.
234        #[inline(always)]
235        pub fn session_id(&self) -> SessionId {
236            self.stage.session_id()
237        }
238    }
239}
240
241/// Bidirectional local stream implementations.
242pub mod bilocal {
243    use super::*;
244    use types::*;
245
246    /// QUIC bidirectional local stream.
247    pub type StreamBiLocalQuic = Stream<BiLocal, Quic>;
248
249    /// HTTP3 bidirectional local stream.
250    pub type StreamBiLocalH3 = Stream<BiLocal, H3>;
251
252    /// WebTransport bidirectional local stream.
253    pub type StreamBiLocalWT = Stream<BiLocal, WT>;
254
255    impl StreamBiLocalQuic {
256        /// Creates a new locally-initialized bidirectional stream.
257        pub fn open_bi() -> Self {
258            Self {
259                kind: BiLocal::default(),
260                stage: Quic,
261            }
262        }
263
264        /// Upgrades to an HTTP3 stream.
265        pub fn upgrade(self) -> StreamBiLocalH3 {
266            StreamBiLocalH3 {
267                kind: self.kind,
268                stage: H3::new(None),
269            }
270        }
271    }
272
273    impl StreamBiLocalH3 {
274        /// See [`Frame::read`].
275        pub fn read_frame<'a, R>(
276            &self,
277            bytes_reader: &mut R,
278        ) -> Result<Option<Frame<'a>>, ErrorCode>
279        where
280            R: BytesReader<'a>,
281        {
282            loop {
283                match Frame::read(bytes_reader) {
284                    Ok(Some(frame)) => {
285                        return Ok(Some(self.validate_frame(frame)?));
286                    }
287                    Ok(None) => {
288                        return Ok(None);
289                    }
290                    Err(frame::ParseError::UnknownFrame) => {
291                        continue;
292                    }
293                    Err(frame::ParseError::InvalidSessionId) => {
294                        return Err(ErrorCode::Id);
295                    }
296                    Err(frame::ParseError::PayloadTooBig) => {
297                        return Err(ErrorCode::ExcessiveLoad);
298                    }
299                }
300            }
301        }
302
303        /// See [`Frame::read_async`].
304        #[cfg(feature = "async")]
305        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
306        pub async fn read_frame_async<'a, R>(
307            &self,
308            reader: &mut R,
309        ) -> Result<Frame<'a>, IoReadError>
310        where
311            R: AsyncRead + Unpin + ?Sized,
312        {
313            loop {
314                match Frame::read_async(reader).await {
315                    Ok(frame) => {
316                        return self.validate_frame(frame).map_err(IoReadError::H3);
317                    }
318                    Err(frame::IoReadError::Parse(frame::ParseError::UnknownFrame)) => {
319                        continue;
320                    }
321                    Err(frame::IoReadError::Parse(frame::ParseError::InvalidSessionId)) => {
322                        return Err(IoReadError::H3(ErrorCode::Id));
323                    }
324                    Err(frame::IoReadError::Parse(frame::ParseError::PayloadTooBig)) => {
325                        return Err(IoReadError::H3(ErrorCode::ExcessiveLoad));
326                    }
327                    Err(frame::IoReadError::IO(io_error)) => {
328                        if matches!(io_error, bytes::IoReadError::UnexpectedFin) {
329                            return Err(IoReadError::H3(ErrorCode::Frame));
330                        }
331
332                        return Err(IoReadError::IO(io_error));
333                    }
334                }
335            }
336        }
337
338        /// See [`Frame::read_from_buffer`].
339        pub fn read_frame_from_buffer<'a>(
340            &self,
341            buffer_reader: &mut BufferReader<'a>,
342        ) -> Result<Option<Frame<'a>>, ErrorCode> {
343            let mut buffer_reader_child = buffer_reader.child();
344
345            match self.read_frame(&mut *buffer_reader_child)? {
346                Some(frame) => {
347                    buffer_reader_child.commit();
348                    Ok(Some(frame))
349                }
350                None => Ok(None),
351            }
352        }
353
354        /// See [`Frame::write`].
355        ///
356        /// # Panics
357        ///
358        /// Panics if [`FrameKind::WebTransport`] (use `upgrade` for that).
359        pub fn write_frame<W>(
360            &mut self,
361            frame: Frame,
362            bytes_writer: &mut W,
363        ) -> Result<(), EndOfBuffer>
364        where
365            W: BytesWriter,
366        {
367            assert!(!matches!(frame.kind(), FrameKind::WebTransport));
368            frame.write(bytes_writer)?;
369            self.stage.set_first_frame();
370            Ok(())
371        }
372
373        /// See [`Frame::write_async`].
374        ///
375        /// # Panics
376        ///
377        /// Panics if [`FrameKind::WebTransport`] (use `upgrade` for that).
378        #[cfg(feature = "async")]
379        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
380        pub async fn write_frame_async<W>(
381            &mut self,
382            frame: Frame<'_>,
383            writer: &mut W,
384        ) -> Result<(), IoWriteError>
385        where
386            W: AsyncWrite + Unpin + ?Sized,
387        {
388            assert!(!matches!(frame.kind(), FrameKind::WebTransport));
389            frame.write_async(writer).await?;
390            self.stage.set_first_frame();
391            Ok(())
392        }
393
394        /// See [`Frame::write_to_buffer`].
395        ///
396        /// # Panics
397        ///
398        /// Panics if [`FrameKind::WebTransport`] (use `upgrade` for that).
399        pub fn write_frame_to_buffer(
400            &mut self,
401            frame: Frame,
402            buffer_writer: &mut BufferWriter,
403        ) -> Result<(), EndOfBuffer> {
404            assert!(!matches!(frame.kind(), FrameKind::WebTransport));
405            frame.write_to_buffer(buffer_writer)?;
406            self.stage.set_first_frame();
407            Ok(())
408        }
409
410        /// Upgrades to a WebTransport stream.
411        ///
412        /// # Panics
413        ///
414        /// * Panics if any other I/O operation has been performed on this stream before upgrade.
415        /// * Panics if `bytes_writer` does not have enough capacity. See [`Self::upgrade_size`].
416        pub fn upgrade<W>(mut self, session_id: SessionId, bytes_writer: &mut W) -> StreamBiLocalWT
417        where
418            W: BytesWriter,
419        {
420            assert!(!self.stage.set_first_frame());
421
422            Frame::new_webtransport(session_id)
423                .write(bytes_writer)
424                .expect("Upgrade failed because buffer too short");
425
426            StreamBiLocalWT {
427                kind: self.kind,
428                stage: WT::new(session_id),
429            }
430        }
431
432        /// Upgrades to a WebTransport stream.
433        ///
434        /// # Panics
435        ///
436        /// * Panics if any other I/O operation has been performed on this stream before upgrade.
437        #[cfg(feature = "async")]
438        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
439        pub async fn upgrade_async<W>(
440            mut self,
441            session_id: SessionId,
442            writer: &mut W,
443        ) -> Result<StreamBiLocalWT, IoWriteError>
444        where
445            W: AsyncWrite + Unpin + ?Sized,
446        {
447            assert!(!self.stage.set_first_frame());
448
449            Frame::new_webtransport(session_id)
450                .write_async(writer)
451                .await?;
452
453            Ok(StreamBiLocalWT {
454                kind: self.kind,
455                stage: WT::new(session_id),
456            })
457        }
458
459        /// Returns the needed capacity for upgrade via [`Self::upgrade`].
460        pub fn upgrade_size(&self, session_id: SessionId) -> usize {
461            Frame::new_webtransport(session_id).write_size()
462        }
463
464        /// Converts the stream into a `StreamSession`.
465        pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
466            session::StreamSession {
467                kind: Bi,
468                stage: Session::new(session_request),
469            }
470        }
471
472        fn validate_frame<'a>(&self, frame: Frame<'a>) -> Result<Frame<'a>, ErrorCode> {
473            match frame.kind() {
474                FrameKind::Data => Ok(frame),
475                FrameKind::Headers => Ok(frame),
476                FrameKind::Settings => Err(ErrorCode::FrameUnexpected),
477                FrameKind::WebTransport => Err(ErrorCode::FrameUnexpected),
478                FrameKind::Exercise(_) => Ok(frame),
479            }
480        }
481    }
482
483    impl StreamBiLocalWT {
484        /// Returns the [`SessionId`] associated with this stream.
485        #[inline(always)]
486        pub fn session_id(&self) -> SessionId {
487            self.stage.session_id()
488        }
489    }
490}
491
492/// unidirectional remote stream implementations.
493pub mod uniremote {
494    use super::*;
495    use types::*;
496
497    /// A result of attempt to upgrade a `UniRemote` stream.
498    pub enum MaybeUpgradeH3 {
499        /// Stream cannot be upgraded. Not enough data.
500        Quic(StreamUniRemoteQuic),
501
502        /// Stream upgraded to HTTP3.
503        H3(StreamUniRemoteH3),
504    }
505
506    /// QUIC unidirectional remote stream.
507    pub type StreamUniRemoteQuic = Stream<UniRemote, Quic>;
508
509    /// HTTP3 unidirectional remote stream.
510    pub type StreamUniRemoteH3 = Stream<UniRemote, H3>;
511
512    /// WebTransport unidirectional remote stream.
513    pub type StreamUniRemoteWT = Stream<UniRemote, WT>;
514
515    impl StreamUniRemoteQuic {
516        /// Creates a new remote-initialized unidirectional stream.
517        pub fn accept_uni() -> Self {
518            Self {
519                kind: UniRemote::default(),
520                stage: Quic,
521            }
522        }
523
524        /// Upgrades to an HTTP3 stream.
525        ///
526        /// Because `bytes_reader` could not contain all required data, this behaves more like
527        /// an attempt of upgrading.
528        ///
529        /// In case there are no enough information, [`MaybeUpgradeH3::Quic`] (i.e, `self`)
530        /// will be returned.
531        ///
532        /// If the stream type is unknown [`ErrorCode::StreamCreation`] is returned.
533        /// In that case, MUST NOT consider unknown stream types to be a connection error of any kind.
534        pub fn upgrade<'a, R>(self, bytes_reader: &mut R) -> Result<MaybeUpgradeH3, ErrorCode>
535        where
536            R: BytesReader<'a>,
537        {
538            match StreamHeader::read(bytes_reader) {
539                Ok(Some(stream_header)) => Ok(MaybeUpgradeH3::H3(StreamUniRemoteH3 {
540                    kind: self.kind,
541                    stage: H3::new(Some(stream_header)),
542                })),
543                Ok(None) => Ok(MaybeUpgradeH3::Quic(self)),
544                Err(stream_header::ParseError::UnknownStream) => Err(ErrorCode::StreamCreation),
545                Err(stream_header::ParseError::InvalidSessionId) => Err(ErrorCode::Id),
546            }
547        }
548
549        /// Upgrades to an HTTP3 stream.
550        #[cfg(feature = "async")]
551        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
552        pub async fn upgrade_async<R>(
553            self,
554            reader: &mut R,
555        ) -> Result<StreamUniRemoteH3, IoReadError>
556        where
557            R: AsyncRead + Unpin + ?Sized,
558        {
559            match StreamHeader::read_async(reader).await {
560                Ok(stream_header) => Ok(StreamUniRemoteH3 {
561                    kind: self.kind,
562                    stage: H3::new(Some(stream_header)),
563                }),
564
565                Err(stream_header::IoReadError::Parse(
566                    stream_header::ParseError::UnknownStream,
567                )) => Err(IoReadError::H3(ErrorCode::StreamCreation)),
568
569                Err(stream_header::IoReadError::Parse(
570                    stream_header::ParseError::InvalidSessionId,
571                )) => Err(IoReadError::H3(ErrorCode::Id)),
572
573                Err(stream_header::IoReadError::IO(io_error)) => {
574                    if matches!(io_error, bytes::IoReadError::UnexpectedFin) {
575                        // TODO(bfesta): Check if this scenario use Frame code error
576                        Err(IoReadError::H3(ErrorCode::Frame))
577                    } else {
578                        Err(IoReadError::IO(io_error))
579                    }
580                }
581            }
582        }
583    }
584
585    impl StreamUniRemoteH3 {
586        /// See [`Frame::read`].
587        ///
588        /// # Panics
589        ///
590        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
591        pub fn read_frame<'a, R>(
592            &mut self,
593            bytes_reader: &mut R,
594        ) -> Result<Option<Frame<'a>>, ErrorCode>
595        where
596            R: BytesReader<'a>,
597        {
598            assert!(!matches!(self.kind(), StreamKind::WebTransport));
599
600            loop {
601                match Frame::read(bytes_reader) {
602                    Ok(Some(frame)) => {
603                        return Ok(Some(self.validate_frame(frame)?));
604                    }
605                    Ok(None) => {
606                        return Ok(None);
607                    }
608                    Err(frame::ParseError::UnknownFrame) => {
609                        continue;
610                    }
611                    Err(frame::ParseError::InvalidSessionId) => {
612                        return Err(ErrorCode::Id);
613                    }
614                    Err(frame::ParseError::PayloadTooBig) => {
615                        return Err(ErrorCode::ExcessiveLoad);
616                    }
617                }
618            }
619        }
620
621        /// See [`Frame::read_async`].
622        ///
623        /// # Panics
624        ///
625        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
626        #[cfg(feature = "async")]
627        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
628        pub async fn read_frame_async<'a, R>(
629            &mut self,
630            reader: &mut R,
631        ) -> Result<Frame<'a>, IoReadError>
632        where
633            R: AsyncRead + Unpin + ?Sized,
634        {
635            assert!(!matches!(self.kind(), StreamKind::WebTransport));
636
637            loop {
638                match Frame::read_async(reader).await {
639                    Ok(frame) => {
640                        return self.validate_frame(frame).map_err(IoReadError::H3);
641                    }
642                    Err(frame::IoReadError::Parse(frame::ParseError::UnknownFrame)) => {
643                        continue;
644                    }
645                    Err(frame::IoReadError::Parse(frame::ParseError::InvalidSessionId)) => {
646                        return Err(IoReadError::H3(ErrorCode::Id));
647                    }
648                    Err(frame::IoReadError::Parse(frame::ParseError::PayloadTooBig)) => {
649                        return Err(IoReadError::H3(ErrorCode::ExcessiveLoad));
650                    }
651                    Err(frame::IoReadError::IO(io_error)) => {
652                        if matches!(io_error, bytes::IoReadError::UnexpectedFin) {
653                            return Err(IoReadError::H3(ErrorCode::Frame));
654                        }
655
656                        return Err(IoReadError::IO(io_error));
657                    }
658                }
659            }
660        }
661
662        /// See [`Frame::read_from_buffer`].
663        ///
664        /// # Panics
665        ///
666        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
667        pub fn read_frame_from_buffer<'a>(
668            &mut self,
669            buffer_reader: &mut BufferReader<'a>,
670        ) -> Result<Option<Frame<'a>>, ErrorCode> {
671            let mut buffer_reader_child = buffer_reader.child();
672
673            match self.read_frame(&mut *buffer_reader_child)? {
674                Some(frame) => {
675                    buffer_reader_child.commit();
676                    Ok(Some(frame))
677                }
678                None => Ok(None),
679            }
680        }
681
682        /// Upgrades to a WebTransport stream.
683        ///
684        /// # Panics
685        ///
686        /// Panics if the stream kind is not [`StreamKind::WebTransport`].
687        pub fn upgrade(self) -> StreamUniRemoteWT {
688            assert!(matches!(self.kind(), StreamKind::WebTransport));
689
690            StreamUniRemoteWT {
691                kind: self.kind,
692                stage: WT::new(
693                    self.stage
694                        .stream_header()
695                        .expect("Unistream has header")
696                        .session_id()
697                        .expect("WebTransport type has session id"),
698                ),
699            }
700        }
701
702        /// Returns the [`StreamKind`] associated with the stream.
703        pub fn kind(&self) -> StreamKind {
704            self.stage
705                .stream_header()
706                .expect("Unistream has header")
707                .kind()
708        }
709
710        /// Returns the [`SessionId`] if stream is [`StreamKind::WebTransport`],
711        /// otherwise returns [`None`].
712        pub fn session_id(&self) -> Option<SessionId> {
713            self.stage
714                .stream_header()
715                .expect("Unistream has header")
716                .session_id()
717        }
718
719        fn validate_frame<'a>(&mut self, frame: Frame<'a>) -> Result<Frame<'a>, ErrorCode> {
720            match frame.kind() {
721                FrameKind::Data => Err(ErrorCode::FrameUnexpected),
722                FrameKind::Headers => Err(ErrorCode::FrameUnexpected),
723                FrameKind::Settings => Ok(frame),
724                FrameKind::WebTransport => Err(ErrorCode::FrameUnexpected),
725                FrameKind::Exercise(_) => Ok(frame),
726            }
727        }
728    }
729
730    impl StreamUniRemoteWT {
731        /// Returns the [`SessionId`] associated with this stream.
732        #[inline(always)]
733        pub fn session_id(&self) -> SessionId {
734            self.stage.session_id()
735        }
736    }
737}
738
739/// Unidirectional local stream implementations.
740pub mod unilocal {
741    use super::*;
742    use types::*;
743
744    /// QUIC unidirectional remote stream.
745    pub type StreamUniLocalQuic = Stream<UniLocal, Quic>;
746
747    /// HTTP3 unidirectional remote stream.
748    pub type StreamUniLocalH3 = Stream<UniLocal, H3>;
749
750    /// WebTransport unidirectional remote stream.
751    pub type StreamUniLocalWT = Stream<UniLocal, WT>;
752
753    impl StreamUniLocalQuic {
754        /// Creates a new locally-initialized unidirectional stream.
755        pub fn open_uni() -> Self {
756            Self {
757                kind: UniLocal::default(),
758                stage: Quic,
759            }
760        }
761
762        /// Upgrades to an HTTP3 stream.
763        ///
764        /// # Panics
765        ///
766        /// Panics if `bytes_writer` does not have enough capacity to write
767        /// the `stream_header`.
768        /// Check it with [`Self::upgrade_size`].
769        pub fn upgrade<W>(
770            self,
771            stream_header: StreamHeader,
772            bytes_writer: &mut W,
773        ) -> StreamUniLocalH3
774        where
775            W: BytesWriter,
776        {
777            stream_header
778                .write(bytes_writer)
779                .expect("Upgrade failed because buffer too short");
780
781            StreamUniLocalH3 {
782                kind: self.kind,
783                stage: H3::new(Some(stream_header)),
784            }
785        }
786
787        /// Upgrades to an HTTP3 stream.
788        #[cfg(feature = "async")]
789        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
790        pub async fn upgrade_async<W>(
791            self,
792            stream_header: StreamHeader,
793            writer: &mut W,
794        ) -> Result<StreamUniLocalH3, IoWriteError>
795        where
796            W: AsyncWrite + Unpin + ?Sized,
797        {
798            stream_header.write_async(writer).await?;
799
800            Ok(StreamUniLocalH3 {
801                kind: self.kind,
802                stage: H3::new(Some(stream_header)),
803            })
804        }
805
806        /// Returns the buffer capacity needed for [`Self::upgrade`].
807        pub fn upgrade_size(stream_header: StreamHeader) -> usize {
808            stream_header.write_size()
809        }
810    }
811
812    impl StreamUniLocalH3 {
813        /// See [`Frame::write`].
814        ///
815        /// # Panics
816        ///
817        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
818        pub fn write_frame<W>(
819            &mut self,
820            frame: Frame,
821            bytes_writer: &mut W,
822        ) -> Result<(), EndOfBuffer>
823        where
824            W: BytesWriter,
825        {
826            assert!(!matches!(self.kind(), StreamKind::WebTransport));
827            frame.write(bytes_writer)?;
828            Ok(())
829        }
830
831        /// See [`Frame::write_async`].
832        ///
833        /// # Panics
834        ///
835        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
836        #[cfg(feature = "async")]
837        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
838        pub async fn write_frame_async<W>(
839            &mut self,
840            frame: Frame<'_>,
841            writer: &mut W,
842        ) -> Result<(), IoWriteError>
843        where
844            W: AsyncWrite + Unpin + ?Sized,
845        {
846            assert!(!matches!(self.kind(), StreamKind::WebTransport));
847            frame.write_async(writer).await?;
848            Ok(())
849        }
850
851        /// See [`Frame::write_to_buffer`].
852        ///
853        /// # Panics
854        ///
855        /// Panics if the stream kind is [`StreamKind::WebTransport`]. In that case, use `upgrade` method.
856        pub fn write_frame_to_buffer(
857            &mut self,
858            frame: Frame,
859            buffer_writer: &mut BufferWriter,
860        ) -> Result<(), EndOfBuffer> {
861            assert!(!matches!(self.kind(), StreamKind::WebTransport));
862            frame.write_to_buffer(buffer_writer)?;
863            Ok(())
864        }
865
866        /// Upgrades to a WebTransport stream.
867        ///
868        /// # Panics
869        ///
870        /// Panics if the stream kind is not [`StreamKind::WebTransport`].
871        pub fn upgrade(self) -> StreamUniLocalWT {
872            assert!(matches!(self.kind(), StreamKind::WebTransport));
873
874            StreamUniLocalWT {
875                kind: self.kind,
876                stage: WT::new(
877                    self.stage
878                        .stream_header()
879                        .expect("Unistream has header")
880                        .session_id()
881                        .expect("WebTransport type has session id"),
882                ),
883            }
884        }
885
886        /// Returns the [`StreamKind`] associated with the stream.
887        pub fn kind(&self) -> StreamKind {
888            self.stage
889                .stream_header()
890                .expect("Unistream has header")
891                .kind()
892        }
893
894        /// Returns the [`SessionId`] if stream is [`StreamKind::WebTransport`],
895        /// otherwise returns [`None`].
896        pub fn session_id(&self) -> Option<SessionId> {
897            self.stage
898                .stream_header()
899                .expect("Unistream has header")
900                .session_id()
901        }
902    }
903
904    impl StreamUniLocalWT {
905        /// Returns the [`SessionId`] associated with this stream.
906        #[inline(always)]
907        pub fn session_id(&self) -> SessionId {
908            self.stage.session_id()
909        }
910    }
911}
912
913/// Bidirectional local/remote stream implementations.
914///
915/// For WebTransport session request/response.
916pub mod session {
917    use super::*;
918    use types::*;
919
920    /// HTTP3 bidirectional stream carrying CONNECT request and response.
921    pub type StreamSession = Stream<Bi, Session>;
922
923    impl StreamSession {
924        /// See [`Frame::read`].
925        pub fn read_frame<'a, R>(
926            &self,
927            bytes_reader: &mut R,
928        ) -> Result<Option<Frame<'a>>, ErrorCode>
929        where
930            R: BytesReader<'a>,
931        {
932            loop {
933                match Frame::read(bytes_reader) {
934                    Ok(Some(frame)) => {
935                        return Ok(Some(self.validate_frame(frame)?));
936                    }
937                    Ok(None) => {
938                        return Ok(None);
939                    }
940                    Err(frame::ParseError::UnknownFrame) => {
941                        continue;
942                    }
943                    Err(frame::ParseError::InvalidSessionId) => {
944                        return Err(ErrorCode::Id);
945                    }
946                    Err(frame::ParseError::PayloadTooBig) => {
947                        return Err(ErrorCode::ExcessiveLoad);
948                    }
949                }
950            }
951        }
952
953        /// See [`Frame::read_async`].
954        #[cfg(feature = "async")]
955        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
956        pub async fn read_frame_async<'a, R>(
957            &self,
958            reader: &mut R,
959        ) -> Result<Frame<'a>, IoReadError>
960        where
961            R: AsyncRead + Unpin + ?Sized,
962        {
963            loop {
964                match Frame::read_async(reader).await {
965                    Ok(frame) => {
966                        return self.validate_frame(frame).map_err(IoReadError::H3);
967                    }
968                    Err(frame::IoReadError::Parse(frame::ParseError::UnknownFrame)) => {
969                        continue;
970                    }
971                    Err(frame::IoReadError::Parse(frame::ParseError::InvalidSessionId)) => {
972                        return Err(IoReadError::H3(ErrorCode::Id));
973                    }
974                    Err(frame::IoReadError::Parse(frame::ParseError::PayloadTooBig)) => {
975                        return Err(IoReadError::H3(ErrorCode::ExcessiveLoad));
976                    }
977                    Err(frame::IoReadError::IO(io_error)) => {
978                        if matches!(io_error, bytes::IoReadError::UnexpectedFin) {
979                            return Err(IoReadError::H3(ErrorCode::Frame));
980                        }
981
982                        return Err(IoReadError::IO(io_error));
983                    }
984                }
985            }
986        }
987
988        /// See [`Frame::read_from_buffer`].
989        pub fn read_frame_from_buffer<'a>(
990            &self,
991            buffer_reader: &mut BufferReader<'a>,
992        ) -> Result<Option<Frame<'a>>, ErrorCode> {
993            let mut buffer_reader_child = buffer_reader.child();
994
995            match self.read_frame(&mut *buffer_reader_child)? {
996                Some(frame) => {
997                    buffer_reader_child.commit();
998                    Ok(Some(frame))
999                }
1000                None => Ok(None),
1001            }
1002        }
1003
1004        /// See [`Frame::write`].
1005        pub fn write_frame<W>(&self, frame: Frame, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
1006        where
1007            W: BytesWriter,
1008        {
1009            frame.write(bytes_writer)
1010        }
1011
1012        /// See [`Frame::write_async`].
1013        #[cfg(feature = "async")]
1014        #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
1015        pub async fn write_frame_async<W>(
1016            &self,
1017            frame: Frame<'_>,
1018            writer: &mut W,
1019        ) -> Result<(), IoWriteError>
1020        where
1021            W: AsyncWrite + Unpin + ?Sized,
1022        {
1023            frame.write_async(writer).await
1024        }
1025
1026        /// See [`Frame::write_to_buffer`].
1027        pub fn write_frame_to_buffer(
1028            &self,
1029            frame: Frame,
1030            buffer_writer: &mut BufferWriter,
1031        ) -> Result<(), EndOfBuffer> {
1032            frame.write_to_buffer(buffer_writer)
1033        }
1034
1035        /// Returns the [`SessionRequest`] associated.
1036        #[inline(always)]
1037        pub fn request(&self) -> &SessionRequest {
1038            self.stage.request()
1039        }
1040
1041        fn validate_frame<'a>(&self, frame: Frame<'a>) -> Result<Frame<'a>, ErrorCode> {
1042            match frame.kind() {
1043                FrameKind::Data => Ok(frame),
1044                FrameKind::Headers => Ok(frame),
1045                FrameKind::Settings => Err(ErrorCode::FrameUnexpected),
1046                FrameKind::WebTransport => Err(ErrorCode::FrameUnexpected),
1047                FrameKind::Exercise(_) => Ok(frame),
1048            }
1049        }
1050    }
1051}
1052
1053/// Types and states of a stream.
1054pub mod types {
1055    use super::*;
1056
1057    /// QUIC stream type.
1058    pub struct Quic;
1059
1060    /// HTTP3 stream type.
1061    pub struct H3 {
1062        stream_header: Option<StreamHeader>,
1063        first_frame_done: bool,
1064    }
1065
1066    impl H3 {
1067        #[inline(always)]
1068        pub(super) fn new(stream_header: Option<StreamHeader>) -> Self {
1069            Self {
1070                stream_header,
1071                first_frame_done: false,
1072            }
1073        }
1074
1075        /// Sets the first frame to done.
1076        ///
1077        /// Returns the previous value (false it this is the first frame).
1078        #[inline(always)]
1079        pub(super) fn set_first_frame(&mut self) -> bool {
1080            std::mem::replace(&mut self.first_frame_done, true)
1081        }
1082
1083        #[inline(always)]
1084        pub(super) fn stream_header(&self) -> Option<&StreamHeader> {
1085            self.stream_header.as_ref()
1086        }
1087    }
1088
1089    /// WebTransport stream type.
1090    pub struct WT {
1091        session_id: SessionId,
1092    }
1093
1094    impl WT {
1095        #[inline(always)]
1096        pub(super) fn new(session_id: SessionId) -> Self {
1097            Self { session_id }
1098        }
1099
1100        #[inline(always)]
1101        pub(super) fn session_id(&self) -> SessionId {
1102            self.session_id
1103        }
1104    }
1105
1106    /// Session (HTTP3-CONNECT) stream type.
1107    #[derive(Debug)]
1108    pub struct Session {
1109        session_request: SessionRequest,
1110    }
1111
1112    impl Session {
1113        #[inline(always)]
1114        pub(super) fn new(session_request: SessionRequest) -> Self {
1115            Self { session_request }
1116        }
1117
1118        #[inline(always)]
1119        pub(super) fn request(&self) -> &SessionRequest {
1120            &self.session_request
1121        }
1122    }
1123
1124    /// Bidirectional stream type.
1125    #[derive(Debug)]
1126    pub struct Bi;
1127
1128    /// Unidirectional stream type.
1129    #[derive(Debug)]
1130    pub struct Uni;
1131
1132    /// Remote-initialized stream type.
1133    #[derive(Debug)]
1134    pub struct Remote;
1135
1136    /// Local-initialized stream type.
1137    #[derive(Debug)]
1138    pub struct Local;
1139
1140    /// Remote-initialized bi-directional stream type.
1141    #[derive(Debug)]
1142    pub struct BiRemote(Bi, Remote);
1143
1144    impl Default for BiRemote {
1145        #[inline(always)]
1146        fn default() -> Self {
1147            Self(Bi, Remote)
1148        }
1149    }
1150
1151    /// Local-initialized bi-directional stream type.
1152    pub struct BiLocal(Bi, Local);
1153
1154    impl Default for BiLocal {
1155        #[inline(always)]
1156        fn default() -> Self {
1157            Self(Bi, Local)
1158        }
1159    }
1160
1161    /// Remote-initialized uni-directional stream type.
1162    pub struct UniRemote(Uni, Remote);
1163
1164    impl Default for UniRemote {
1165        #[inline(always)]
1166        fn default() -> Self {
1167            Self(Uni, Remote)
1168        }
1169    }
1170
1171    /// Local-initialized uni-directional stream type.
1172    pub struct UniLocal(Uni, Local);
1173
1174    impl Default for UniLocal {
1175        #[inline(always)]
1176        fn default() -> Self {
1177            Self(Uni, Local)
1178        }
1179    }
1180}
1181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::varint::VarInt;
    use std::borrow::Cow;

    /// Session id carried by every WEBTRANSPORT frame in these tests.
    fn zero_session_id() -> SessionId {
        SessionId::maybe_invalid(VarInt::from_u32(0))
    }

    #[test]
    fn bi_remote_webtransport() {
        let mut encoded = Vec::new();
        Frame::new_webtransport(zero_session_id())
            .write(&mut encoded)
            .unwrap();

        let mut reader = BufferReader::new(encoded.as_slice());
        let mut h3_stream = Stream::accept_bi().upgrade();

        let first_frame = h3_stream
            .read_frame_from_buffer(&mut reader)
            .unwrap()
            .unwrap();
        let session_id = first_frame.session_id().unwrap();

        let wt_stream = h3_stream.upgrade(session_id);
        assert_eq!(wt_stream.session_id(), zero_session_id());
    }

    #[test]
    fn bi_remote_webtransport_not_first() {
        let mut encoded = Vec::new();
        Frame::new_exercise(VarInt::from_u32(0x21), Cow::Borrowed(b"Payload"))
            .write(&mut encoded)
            .unwrap();
        Frame::new_webtransport(zero_session_id())
            .write(&mut encoded)
            .unwrap();

        let mut reader = BufferReader::new(encoded.as_slice());
        let mut h3_stream = Stream::accept_bi().upgrade();

        let exercise = h3_stream
            .read_frame_from_buffer(&mut reader)
            .unwrap()
            .unwrap();
        assert!(matches!(exercise.kind(), FrameKind::Exercise(_)));

        // A WEBTRANSPORT frame that follows another frame must be rejected.
        let outcome = h3_stream.read_frame_from_buffer(&mut reader);
        assert!(matches!(outcome, Err(ErrorCode::Frame)));
    }
}