wtransport_proto/
stream_header.rs

1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::ids::InvalidSessionId;
7use crate::ids::SessionId;
8use crate::varint::VarInt;
9
10#[cfg(feature = "async")]
11use crate::bytes::AsyncRead;
12
13#[cfg(feature = "async")]
14use crate::bytes::AsyncWrite;
15
16#[cfg(feature = "async")]
17use crate::bytes;
18
19/// Error stream header parsing.
20#[derive(Debug, thiserror::Error)]
21pub enum ParseError {
22    /// Error for unknown stream type.
23    #[error("cannot parse HTTP3 stream header as ID is unknown")]
24    UnknownStream,
25
26    /// Error for invalid session ID.
27    #[error("cannot parse HTTP3 stream header as session ID is invalid")]
28    InvalidSessionId,
29}
30
/// Failure modes of reading a stream header from an asynchronous source.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug, thiserror::Error)]
pub enum IoReadError {
    /// The bytes were read, but they do not form a valid stream header.
    #[error(transparent)]
    Parse(ParseError),

    /// The underlying I/O read operation failed.
    #[error(transparent)]
    IO(bytes::IoReadError),
}
44
#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Wraps a low-level read error into [`IoReadError::IO`].
    #[inline(always)]
    fn from(source: bytes::IoReadError) -> Self {
        Self::IO(source)
    }
}
52
/// Error returned when writing a stream header to an asynchronous sink fails.
///
/// This is an alias of [`bytes::IoWriteError`].
#[cfg(feature = "async")]
pub type IoWriteError = bytes::IoWriteError;
56
57/// An HTTP3 stream type.
58#[derive(Copy, Clone, Debug)]
59pub enum StreamKind {
60    /// CONTROL stream type.
61    Control,
62
63    /// QPACK Encoder stream type.
64    QPackEncoder,
65
66    /// QPACK Decoder stream type.
67    QPackDecoder,
68
69    /// WebTransport stream type.
70    WebTransport,
71
72    /// Exercise stream.
73    Exercise(VarInt),
74}
75
76impl StreamKind {
77    /// Checks whether an `id` is valid for a [`StreamKind::Exercise`].
78    #[inline(always)]
79    pub const fn is_id_exercise(id: VarInt) -> bool {
80        id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
81    }
82
83    const fn parse(id: VarInt) -> Option<Self> {
84        match id {
85            stream_type_ids::CONTROL_STREAM => Some(StreamKind::Control),
86            stream_type_ids::QPACK_ENCODER_STREAM => Some(StreamKind::QPackEncoder),
87            stream_type_ids::QPACK_DECODER_STREAM => Some(StreamKind::QPackDecoder),
88            stream_type_ids::WEBTRANSPORT_STREAM => Some(StreamKind::WebTransport),
89            id if StreamKind::is_id_exercise(id) => Some(StreamKind::Exercise(id)),
90            _ => None,
91        }
92    }
93
94    const fn id(self) -> VarInt {
95        match self {
96            StreamKind::Control => stream_type_ids::CONTROL_STREAM,
97            StreamKind::QPackEncoder => stream_type_ids::QPACK_ENCODER_STREAM,
98            StreamKind::QPackDecoder => stream_type_ids::QPACK_DECODER_STREAM,
99            StreamKind::WebTransport => stream_type_ids::WEBTRANSPORT_STREAM,
100            StreamKind::Exercise(id) => id,
101        }
102    }
103}
104
105/// HTTP3 stream type.
106///
107/// *Unidirectional* HTTP3 streams have an header encoding the type.
108pub struct StreamHeader {
109    kind: StreamKind,
110    session_id: Option<SessionId>,
111}
112
113impl StreamHeader {
114    /// Maximum number of bytes a [`StreamHeader`] can take over network.
115    pub const MAX_SIZE: usize = 16;
116
117    /// Creates a new stream header of type [`StreamKind::Control`].
118    #[inline(always)]
119    pub fn new_control() -> Self {
120        Self::new(StreamKind::Control, None)
121    }
122
123    /// Creates a new stream header of type [`StreamKind::WebTransport`].
124    #[inline(always)]
125    pub fn new_webtransport(session_id: SessionId) -> Self {
126        Self::new(StreamKind::WebTransport, Some(session_id))
127    }
128
129    /// Reads a [`StreamHeader`] from a [`BytesReader`].
130    ///
131    /// It returns [`None`] if the `bytes_reader` does not contain enough bytes
132    /// to parse an entire header.
133    ///
134    /// In case [`None`] or [`Err`], `bytes_reader` might be partially read.
135    pub fn read<'a, R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
136    where
137        R: BytesReader<'a>,
138    {
139        let kind = match bytes_reader.get_varint() {
140            Some(kind_id) => StreamKind::parse(kind_id).ok_or(ParseError::UnknownStream)?,
141            None => return Ok(None),
142        };
143
144        let session_id = if matches!(kind, StreamKind::WebTransport) {
145            let session_id = match bytes_reader.get_varint() {
146                Some(session_id) => SessionId::try_from_varint(session_id)
147                    .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
148                None => return Ok(None),
149            };
150
151            Some(session_id)
152        } else {
153            None
154        };
155
156        Ok(Some(Self::new(kind, session_id)))
157    }
158
159    /// Reads a [`StreamHeader`] from a `reader`.
160    #[cfg(feature = "async")]
161    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
162    pub async fn read_async<R>(reader: &mut R) -> Result<Self, IoReadError>
163    where
164        R: AsyncRead + Unpin + ?Sized,
165    {
166        use crate::bytes::BytesReaderAsync;
167
168        let kind_id = reader.get_varint().await?;
169        let kind =
170            StreamKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownStream))?;
171
172        let session_id = if matches!(kind, StreamKind::WebTransport) {
173            let session_id =
174                SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
175                    bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
176                    _ => e,
177                })?)
178                .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
179
180            Some(session_id)
181        } else {
182            None
183        };
184
185        Ok(Self::new(kind, session_id))
186    }
187
188    /// Reads a [`StreamHeader`] from a [`BufferReader`].
189    ///
190    /// It returns [`None`] if the `buffer_reader` does not contain enough bytes
191    /// to parse an entire header.
192    ///
193    /// In case [`None`] or [`Err`], `buffer_reader` offset if not advanced.
194    pub fn read_from_buffer(buffer_reader: &mut BufferReader) -> Result<Option<Self>, ParseError> {
195        let mut buffer_reader_child = buffer_reader.child();
196
197        match Self::read(&mut *buffer_reader_child)? {
198            Some(header) => {
199                buffer_reader_child.commit();
200                Ok(Some(header))
201            }
202            None => Ok(None),
203        }
204    }
205
206    /// Writes a [`StreamHeader`] into a [`BytesWriter`].
207    ///
208    /// It returns [`Err`] if the `bytes_writer` does not have enough capacity
209    /// to write the entire header.
210    /// See [`Self::write_size`] to retrieve the exact amount of required capacity.
211    ///
212    /// In case [`Err`], `bytes_writer` might be partially written.
213    pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
214    where
215        W: BytesWriter,
216    {
217        bytes_writer.put_varint(self.kind.id())?;
218
219        if let Some(session_id) = self.session_id() {
220            bytes_writer.put_varint(session_id.into_varint())?;
221        }
222
223        Ok(())
224    }
225
226    /// Writes a [`StreamHeader`] into a `writer`.
227    #[cfg(feature = "async")]
228    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
229    pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
230    where
231        W: AsyncWrite + Unpin + ?Sized,
232    {
233        use crate::bytes::BytesWriterAsync;
234
235        writer.put_varint(self.kind.id()).await?;
236
237        if let Some(session_id) = self.session_id() {
238            writer.put_varint(session_id.into_varint()).await?;
239        }
240
241        Ok(())
242    }
243
244    /// Writes this [`StreamHeader`] into a buffer via [`BufferWriter`].
245    ///
246    /// In case [`Err`], `buffer_writer` is not advanced.
247    pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
248        if buffer_writer.capacity() < self.write_size() {
249            return Err(EndOfBuffer);
250        }
251
252        self.write(buffer_writer)
253            .expect("Enough capacity for header");
254
255        Ok(())
256    }
257
258    /// Returns the needed capacity to write this stream header into a buffer.
259    pub fn write_size(&self) -> usize {
260        if let Some(session_id) = self.session_id() {
261            self.kind.id().size() + session_id.into_varint().size()
262        } else {
263            self.kind.id().size()
264        }
265    }
266
267    /// Returns the [`StreamKind`].
268    #[inline(always)]
269    pub const fn kind(&self) -> StreamKind {
270        self.kind
271    }
272
273    /// Returns the [`SessionId`] if stream is [`StreamKind::WebTransport`],
274    /// otherwise returns [`None`].
275    #[inline(always)]
276    pub fn session_id(&self) -> Option<SessionId> {
277        matches!(self.kind, StreamKind::WebTransport).then(|| {
278            self.session_id
279                .expect("WebTransport stream header contains session id")
280        })
281    }
282
283    fn new(kind: StreamKind, session_id: Option<SessionId>) -> Self {
284        if let StreamKind::Exercise(id) = kind {
285            debug_assert!(StreamKind::is_id_exercise(id));
286            debug_assert!(session_id.is_none());
287        } else if let StreamKind::WebTransport = kind {
288            debug_assert!(session_id.is_some());
289        } else {
290            debug_assert!(session_id.is_none());
291        }
292
293        Self { kind, session_id }
294    }
295
296    #[cfg(test)]
297    pub(crate) fn serialize_any(kind: VarInt) -> Vec<u8> {
298        let mut buffer = Vec::new();
299
300        Self {
301            kind: StreamKind::Exercise(kind),
302            session_id: None,
303        }
304        .write(&mut buffer)
305        .unwrap();
306
307        buffer
308    }
309
310    #[cfg(test)]
311    pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
312        let mut buffer = Vec::new();
313
314        Self {
315            kind: StreamKind::WebTransport,
316            session_id: Some(session_id),
317        }
318        .write(&mut buffer)
319        .unwrap();
320
321        buffer
322    }
323}
324
325mod stream_type_ids {
326    use crate::varint::VarInt;
327
328    pub const CONTROL_STREAM: VarInt = VarInt::from_u32(0x0);
329    pub const QPACK_ENCODER_STREAM: VarInt = VarInt::from_u32(0x02);
330    pub const QPACK_DECODER_STREAM: VarInt = VarInt::from_u32(0x03);
331    pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x54);
332}
333
// Round-trip and error-path tests for `StreamHeader`.
//
// Every test that touches the async API is gated on the `async` feature,
// because `read_async`, `write_async`, `IoReadError` and the `bytes` import
// only exist when that feature is enabled.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn control() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn control_async() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[test]
    fn read_eof() {
        // A header truncated by one byte must be reported as incomplete.
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));
        assert!(StreamHeader::read(&mut &buffer[..buffer.len() - 1])
            .unwrap()
            .is_none());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                // Nothing was read at all: the stream ended before the header.
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                // The stream ended in the middle of the header.
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                // Includes the case where the stream type was fully read but
                // the session ID is missing.
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_stream() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownStream)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_stream_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownStream))
        ));
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    mod utils {
        use super::*;

        /// Writes `stream_header`, checks the encoded size, reads it back and
        /// returns the decoded header.
        pub fn assert_serde(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read(&mut buffer).unwrap().unwrap();
            // The whole encoding must be consumed by the read.
            assert!(buffer.is_empty());

            stream_header
        }

        /// Async counterpart of [`assert_serde`].
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read_async(&mut buffer).await.unwrap();
            // The whole encoding must be consumed by the read.
            assert!(buffer.is_empty());

            stream_header
        }
    }
}