1use crate::bytes::BufferReader;
2use crate::bytes::BufferWriter;
3use crate::bytes::BytesReader;
4use crate::bytes::BytesWriter;
5use crate::bytes::EndOfBuffer;
6use crate::ids::InvalidSessionId;
7use crate::ids::SessionId;
8use crate::varint::VarInt;
9
10#[cfg(feature = "async")]
11use crate::bytes::AsyncRead;
12
13#[cfg(feature = "async")]
14use crate::bytes::AsyncWrite;
15
16#[cfg(feature = "async")]
17use crate::bytes;
18
19#[derive(Debug, thiserror::Error)]
21pub enum ParseError {
22 #[error("cannot parse HTTP3 stream header as ID is unknown")]
24 UnknownStream,
25
26 #[error("cannot parse HTTP3 stream header as session ID is invalid")]
28 InvalidSessionId,
29}
30
/// Error returned by the asynchronous header reader (`StreamHeader::read_async`).
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
#[derive(Debug, thiserror::Error)]
pub enum IoReadError {
    /// The bytes were received, but they do not form a valid stream header.
    #[error(transparent)]
    Parse(ParseError),

    /// Reading from the underlying stream failed, including the case where
    /// the stream finished before the header was complete.
    #[error(transparent)]
    IO(crate::bytes::IoReadError),
}
44
#[cfg(feature = "async")]
impl From<bytes::IoReadError> for IoReadError {
    /// Lifts a low-level read error into [`IoReadError::IO`] so that `?` can
    /// be used directly on the byte-level async reader.
    #[inline(always)]
    fn from(error: bytes::IoReadError) -> Self {
        Self::IO(error)
    }
}
52
/// Error returned by the asynchronous header writer (`StreamHeader::write_async`).
///
/// This is the byte-level async write error, re-exported under this module.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub type IoWriteError = crate::bytes::IoWriteError;
56
57#[derive(Copy, Clone, Debug)]
59pub enum StreamKind {
60 Control,
62
63 QPackEncoder,
65
66 QPackDecoder,
68
69 WebTransport,
71
72 Exercise(VarInt),
74}
75
76impl StreamKind {
77 #[inline(always)]
79 pub const fn is_id_exercise(id: VarInt) -> bool {
80 id.into_inner() >= 0x21 && ((id.into_inner() - 0x21) % 0x1f == 0)
81 }
82
83 const fn parse(id: VarInt) -> Option<Self> {
84 match id {
85 stream_type_ids::CONTROL_STREAM => Some(StreamKind::Control),
86 stream_type_ids::QPACK_ENCODER_STREAM => Some(StreamKind::QPackEncoder),
87 stream_type_ids::QPACK_DECODER_STREAM => Some(StreamKind::QPackDecoder),
88 stream_type_ids::WEBTRANSPORT_STREAM => Some(StreamKind::WebTransport),
89 id if StreamKind::is_id_exercise(id) => Some(StreamKind::Exercise(id)),
90 _ => None,
91 }
92 }
93
94 const fn id(self) -> VarInt {
95 match self {
96 StreamKind::Control => stream_type_ids::CONTROL_STREAM,
97 StreamKind::QPackEncoder => stream_type_ids::QPACK_ENCODER_STREAM,
98 StreamKind::QPackDecoder => stream_type_ids::QPACK_DECODER_STREAM,
99 StreamKind::WebTransport => stream_type_ids::WEBTRANSPORT_STREAM,
100 StreamKind::Exercise(id) => id,
101 }
102 }
103}
104
105pub struct StreamHeader {
109 kind: StreamKind,
110 session_id: Option<SessionId>,
111}
112
113impl StreamHeader {
114 pub const MAX_SIZE: usize = 16;
116
117 #[inline(always)]
119 pub fn new_control() -> Self {
120 Self::new(StreamKind::Control, None)
121 }
122
123 #[inline(always)]
125 pub fn new_webtransport(session_id: SessionId) -> Self {
126 Self::new(StreamKind::WebTransport, Some(session_id))
127 }
128
129 pub fn read<'a, R>(bytes_reader: &mut R) -> Result<Option<Self>, ParseError>
136 where
137 R: BytesReader<'a>,
138 {
139 let kind = match bytes_reader.get_varint() {
140 Some(kind_id) => StreamKind::parse(kind_id).ok_or(ParseError::UnknownStream)?,
141 None => return Ok(None),
142 };
143
144 let session_id = if matches!(kind, StreamKind::WebTransport) {
145 let session_id = match bytes_reader.get_varint() {
146 Some(session_id) => SessionId::try_from_varint(session_id)
147 .map_err(|InvalidSessionId| ParseError::InvalidSessionId)?,
148 None => return Ok(None),
149 };
150
151 Some(session_id)
152 } else {
153 None
154 };
155
156 Ok(Some(Self::new(kind, session_id)))
157 }
158
159 #[cfg(feature = "async")]
161 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
162 pub async fn read_async<R>(reader: &mut R) -> Result<Self, IoReadError>
163 where
164 R: AsyncRead + Unpin + ?Sized,
165 {
166 use crate::bytes::BytesReaderAsync;
167
168 let kind_id = reader.get_varint().await?;
169 let kind =
170 StreamKind::parse(kind_id).ok_or(IoReadError::Parse(ParseError::UnknownStream))?;
171
172 let session_id = if matches!(kind, StreamKind::WebTransport) {
173 let session_id =
174 SessionId::try_from_varint(reader.get_varint().await.map_err(|e| match e {
175 bytes::IoReadError::ImmediateFin => bytes::IoReadError::UnexpectedFin,
176 _ => e,
177 })?)
178 .map_err(|InvalidSessionId| IoReadError::Parse(ParseError::InvalidSessionId))?;
179
180 Some(session_id)
181 } else {
182 None
183 };
184
185 Ok(Self::new(kind, session_id))
186 }
187
188 pub fn read_from_buffer(buffer_reader: &mut BufferReader) -> Result<Option<Self>, ParseError> {
195 let mut buffer_reader_child = buffer_reader.child();
196
197 match Self::read(&mut *buffer_reader_child)? {
198 Some(header) => {
199 buffer_reader_child.commit();
200 Ok(Some(header))
201 }
202 None => Ok(None),
203 }
204 }
205
206 pub fn write<W>(&self, bytes_writer: &mut W) -> Result<(), EndOfBuffer>
214 where
215 W: BytesWriter,
216 {
217 bytes_writer.put_varint(self.kind.id())?;
218
219 if let Some(session_id) = self.session_id() {
220 bytes_writer.put_varint(session_id.into_varint())?;
221 }
222
223 Ok(())
224 }
225
226 #[cfg(feature = "async")]
228 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
229 pub async fn write_async<W>(&self, writer: &mut W) -> Result<(), IoWriteError>
230 where
231 W: AsyncWrite + Unpin + ?Sized,
232 {
233 use crate::bytes::BytesWriterAsync;
234
235 writer.put_varint(self.kind.id()).await?;
236
237 if let Some(session_id) = self.session_id() {
238 writer.put_varint(session_id.into_varint()).await?;
239 }
240
241 Ok(())
242 }
243
244 pub fn write_to_buffer(&self, buffer_writer: &mut BufferWriter) -> Result<(), EndOfBuffer> {
248 if buffer_writer.capacity() < self.write_size() {
249 return Err(EndOfBuffer);
250 }
251
252 self.write(buffer_writer)
253 .expect("Enough capacity for header");
254
255 Ok(())
256 }
257
258 pub fn write_size(&self) -> usize {
260 if let Some(session_id) = self.session_id() {
261 self.kind.id().size() + session_id.into_varint().size()
262 } else {
263 self.kind.id().size()
264 }
265 }
266
267 #[inline(always)]
269 pub const fn kind(&self) -> StreamKind {
270 self.kind
271 }
272
273 #[inline(always)]
276 pub fn session_id(&self) -> Option<SessionId> {
277 matches!(self.kind, StreamKind::WebTransport).then(|| {
278 self.session_id
279 .expect("WebTransport stream header contains session id")
280 })
281 }
282
283 fn new(kind: StreamKind, session_id: Option<SessionId>) -> Self {
284 if let StreamKind::Exercise(id) = kind {
285 debug_assert!(StreamKind::is_id_exercise(id));
286 debug_assert!(session_id.is_none());
287 } else if let StreamKind::WebTransport = kind {
288 debug_assert!(session_id.is_some());
289 } else {
290 debug_assert!(session_id.is_none());
291 }
292
293 Self { kind, session_id }
294 }
295
296 #[cfg(test)]
297 pub(crate) fn serialize_any(kind: VarInt) -> Vec<u8> {
298 let mut buffer = Vec::new();
299
300 Self {
301 kind: StreamKind::Exercise(kind),
302 session_id: None,
303 }
304 .write(&mut buffer)
305 .unwrap();
306
307 buffer
308 }
309
310 #[cfg(test)]
311 pub(crate) fn serialize_webtransport(session_id: SessionId) -> Vec<u8> {
312 let mut buffer = Vec::new();
313
314 Self {
315 kind: StreamKind::WebTransport,
316 session_id: Some(session_id),
317 }
318 .write(&mut buffer)
319 .unwrap();
320
321 buffer
322 }
323}
324
325mod stream_type_ids {
326 use crate::varint::VarInt;
327
328 pub const CONTROL_STREAM: VarInt = VarInt::from_u32(0x0);
329 pub const QPACK_ENCODER_STREAM: VarInt = VarInt::from_u32(0x02);
330 pub const QPACK_DECODER_STREAM: VarInt = VarInt::from_u32(0x03);
331 pub const WEBTRANSPORT_STREAM: VarInt = VarInt::from_u32(0x54);
332}
333
/// Round-trip, truncation and error-path tests for `StreamHeader`.
///
/// Tests that use the async API are gated on the `async` feature because
/// `read_async`, `write_async`, `IoReadError`, the `bytes` alias and
/// `utils::assert_serde_async` only exist when that feature is enabled.
/// Without the gate, `cargo test` fails to compile in the default feature set.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn control() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn control_async() {
        let stream_header = StreamHeader::new_control();
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::Control));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde(stream_header);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();

        let stream_header = StreamHeader::new_webtransport(session_id);
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));

        let stream_header = utils::assert_serde_async(stream_header).await;
        assert!(matches!(stream_header.kind(), StreamKind::WebTransport));
        assert!(matches!(stream_header.session_id(), Some(x) if x == session_id));
    }

    #[test]
    fn exercise() {
        // 0x40 == 0x1f * 1 + 0x21, i.e. a reserved stream type.
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x40));

        let stream_header = StreamHeader::read(&mut buffer.as_slice())
            .unwrap()
            .unwrap();
        assert!(matches!(
            stream_header.kind(),
            StreamKind::Exercise(id) if id.into_inner() == 0x40
        ));
        assert!(stream_header.session_id().is_none());
    }

    #[test]
    fn read_eof() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        // Every strict prefix (including the empty one) is an incomplete header.
        for len in 0..buffer.len() {
            assert!(StreamHeader::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[test]
    fn read_eof_webtransport() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        // Includes the prefix that holds the full stream type but no session ID.
        for len in 0..buffer.len() {
            assert!(StreamHeader::read(&mut &buffer[..len]).unwrap().is_none());
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_eof_webtransport_async() {
        let session_id = SessionId::try_from_varint(VarInt::from_u32(0)).unwrap();
        let buffer = StreamHeader::serialize_webtransport(session_id);

        for len in 0..buffer.len() {
            let result = StreamHeader::read_async(&mut &buffer[..len]).await;

            match len {
                0 => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::ImmediateFin))
                )),
                _ => assert!(matches!(
                    result,
                    Err(IoReadError::IO(bytes::IoReadError::UnexpectedFin))
                )),
            }
        }
    }

    #[test]
    fn unknown_stream() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::UnknownStream)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn unknown_stream_async() {
        let buffer = StreamHeader::serialize_any(VarInt::from_u32(0x0042_4242));

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::UnknownStream))
        ));
    }

    #[test]
    fn invalid_session_id() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read(&mut buffer.as_slice()),
            Err(ParseError::InvalidSessionId)
        ));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn invalid_session_id_async() {
        let invalid_session_id = SessionId::maybe_invalid(VarInt::from_u32(1));
        let buffer = StreamHeader::serialize_webtransport(invalid_session_id);

        assert!(matches!(
            StreamHeader::read_async(&mut buffer.as_slice()).await,
            Err(IoReadError::Parse(ParseError::InvalidSessionId))
        ));
    }

    mod utils {
        use super::*;

        /// Writes `stream_header`, checks the encoded size, reads it back and
        /// asserts the whole buffer was consumed.
        pub fn assert_serde(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write(&mut buffer).unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read(&mut buffer).unwrap().unwrap();
            assert!(buffer.is_empty());

            stream_header
        }

        /// Async counterpart of [`assert_serde`].
        #[cfg(feature = "async")]
        pub async fn assert_serde_async(stream_header: StreamHeader) -> StreamHeader {
            let mut buffer = Vec::new();

            stream_header.write_async(&mut buffer).await.unwrap();
            assert_eq!(buffer.len(), stream_header.write_size());
            assert!(buffer.len() <= StreamHeader::MAX_SIZE);

            let mut buffer = buffer.as_slice();
            let stream_header = StreamHeader::read_async(&mut buffer).await.unwrap();
            assert!(buffer.is_empty());

            stream_header
        }
    }
}