wtransport_proto/
qpack.rs

1use crate::bytes::BufferReader;
2use crate::bytes::BytesReader;
3use crate::bytes::BytesWriter;
4use crate::bytes::EndOfBuffer;
5use std::collections::HashMap;
6
/// Compile-time assertion, optionally over const generic parameters.
///
/// Two forms are accepted:
///
/// - `const_assert!(Var1: Ty, Var2: Ty, ... => expression)`: expression form,
///   meant to be used inside a function that is generic over the listed const
///   parameters. The condition is checked for each concrete set of parameter
///   values the function is instantiated with, and the macro evaluates to
///   `0u8` when it holds.
/// - `const_assert!(expression)`: item form, for conditions that do not depend
///   on generic parameters. It expands to an anonymous constant, so it can be
///   used any number of times in the same scope.
///
/// In both forms a false condition makes `0 - 1` underflow a `u8` during
/// constant evaluation, which turns into a compilation error.
macro_rules! const_assert {
    ($($list:ident : $ty:ty),* => $expr:expr) => {{
        // The helper struct must declare every parameter with the type the
        // caller supplied, otherwise it disagrees with the `impl` below for
        // anything that is not `usize`.
        struct Assert<$(const $list: $ty,)*>;
        impl<$(const $list: $ty,)*> Assert<$($list,)*> {
            const OK: u8 = 0 - !($expr) as u8;
        }
        Assert::<$($list,)*>::OK
    }};
    ($expr:expr) => {
        // Anonymous on purpose: a named constant would leak into the caller's
        // scope and clash on the second use.
        const _: u8 = 0 - !($expr) as u8;
    };
}
20
21/// Error during decoding operation.
22///
23/// Generated from [`Decoder::decode`].
24#[derive(Debug, thiserror::Error)]
25pub enum DecodingError {
26    /// During decoding input data raeched unexpected EOF.
27    #[error("end of stream reached prematurely")]
28    UnexpectedFin,
29
30    /// Integer decoding produced an overflow.
31    #[error("integer overflow")]
32    IntegerOverflow,
33
34    /// String decoding is invalid (UTF-8 fail or Huffman).
35    #[error("invalid string decoding")]
36    InvalidString,
37
38    /// Encoded data requires dynamic table. It is not supported.
39    #[error("dynamic table is not supported")]
40    DynamicNotSupported,
41
42    /// Index is out-of-bound in the static table.
43    #[error("index not found in the static table")]
44    IndexNotfound,
45}
46
/// Representation of a field line, identified from the leading bits of its
/// first byte (see `Decoder::decode_field_line_type`).
enum FieldLineType {
    /// First byte starts with `1`: indexed field line.
    Indexed,
    /// First byte starts with `0001`: indexed field line with post-base
    /// index. The decoder rejects it as requiring the dynamic table.
    IndexedPost,
    /// First byte starts with `01`: literal value with a name reference.
    LiteralRefName,
    /// First byte starts with `0000`: literal value with a post-base name
    /// reference. The decoder rejects it as requiring the dynamic table.
    LiteralPostRefName,
    /// First byte starts with `001`: literal name and literal value.
    LiteralLitName,
}
54
55/// QPACK decoder.
56///
57/// It only supports stateless decoding, so any data requiring
58/// dynamic table will end up in a [`DecodingError::DynamicNotSupported`]
59/// error.
60pub struct Decoder;
61
62impl Decoder {
63    /// Decodes data stream.
64    ///
65    /// Result is an hash-map of headers.
66    pub fn decode<D>(data: D) -> Result<HashMap<String, String>, DecodingError>
67    where
68        D: AsRef<[u8]>,
69    {
70        let mut buffer_reader = BufferReader::new(data.as_ref());
71
72        Self::decode_integer::<8, _>(&mut buffer_reader)?;
73        Self::decode_integer::<7, _>(&mut buffer_reader)?;
74
75        let mut headers = HashMap::new();
76
77        while buffer_reader.capacity() > 0 {
78            let field = buffer_reader.buffer_remaining()[0];
79
80            match Self::decode_field_line_type(field) {
81                FieldLineType::Indexed => {
82                    let is_dynamic = field & 0b0100_0000 == 0;
83                    if is_dynamic {
84                        return Err(DecodingError::DynamicNotSupported);
85                    }
86
87                    let index = Self::decode_integer::<6, _>(&mut buffer_reader)?.1;
88                    let (key, value) =
89                        StaticTable::lookup_field(index).ok_or(DecodingError::IndexNotfound)?;
90                    headers.insert(key.to_string(), value.to_string());
91                }
92                FieldLineType::IndexedPost => {
93                    return Err(DecodingError::DynamicNotSupported);
94                }
95                FieldLineType::LiteralRefName => {
96                    let is_dynamic = field & 0b0001_0000 == 0;
97                    if is_dynamic {
98                        return Err(DecodingError::DynamicNotSupported);
99                    }
100
101                    let index = Self::decode_integer::<4, _>(&mut buffer_reader)?.1;
102                    let key = StaticTable::lookup_field(index)
103                        .ok_or(DecodingError::IndexNotfound)?
104                        .0;
105                    let value = Self::decode_string::<7, _>(&mut buffer_reader)?;
106
107                    headers.insert(key.to_string(), value);
108                }
109                FieldLineType::LiteralPostRefName => {
110                    return Err(DecodingError::DynamicNotSupported);
111                }
112                FieldLineType::LiteralLitName => {
113                    let key = Self::decode_string::<3, _>(&mut buffer_reader)?;
114                    let value = Self::decode_string::<7, _>(&mut buffer_reader)?;
115
116                    headers.insert(key, value);
117                }
118            }
119        }
120
121        Ok(headers)
122    }
123
124    fn decode_field_line_type(byte: u8) -> FieldLineType {
125        const MASK_INDEXED: u8 = 0b0000_0001;
126        const MASK_INDEXED_POST: u8 = 0b0000_0001;
127        const MASK_LITERAL_REF_NAME: u8 = 0b0000_0001;
128        const MASK_LITERAL_POST_REF_NAME: u8 = 0b0000_0000;
129        const MASK_LITERAL_LIT_NAME: u8 = 0b0000_0001;
130
131        if byte >> 7 == MASK_INDEXED {
132            FieldLineType::Indexed
133        } else if byte >> 4 == MASK_INDEXED_POST {
134            FieldLineType::IndexedPost
135        } else if byte >> 6 == MASK_LITERAL_REF_NAME {
136            FieldLineType::LiteralRefName
137        } else if byte >> 4 == MASK_LITERAL_POST_REF_NAME {
138            FieldLineType::LiteralPostRefName
139        } else if byte >> 5 == MASK_LITERAL_LIT_NAME {
140            FieldLineType::LiteralLitName
141        } else {
142            unreachable!()
143        }
144    }
145
146    fn decode_integer<'a, const N: usize, R>(
147        bytes_reader: &mut R,
148    ) -> Result<(u8, usize), DecodingError>
149    where
150        R: BytesReader<'a>,
151    {
152        const_assert!(N: usize => N <= 8 && N >= 1);
153
154        let byte = bytes_reader
155            .get_bytes(1)
156            .ok_or(DecodingError::UnexpectedFin)?[0] as usize;
157
158        let mask = (0x01 << N) - 1;
159        let flags = (byte >> N) as u8;
160        let mut value = byte & mask;
161
162        if value != mask {
163            return Ok((flags, value));
164        }
165
166        let mut power = 0;
167        loop {
168            let byte = bytes_reader
169                .get_bytes(1)
170                .ok_or(DecodingError::UnexpectedFin)?[0] as usize;
171
172            value = value
173                .checked_add((byte & 0x7F) << power)
174                .ok_or(DecodingError::IntegerOverflow)?;
175
176            power += 7;
177
178            if byte & 0x80 == 0 {
179                break;
180            }
181        }
182
183        Ok((flags, value))
184    }
185
186    fn decode_string<'a, const N: usize, R>(bytes_reader: &mut R) -> Result<String, DecodingError>
187    where
188        R: BytesReader<'a>,
189    {
190        let (flags, string_len) = Self::decode_integer::<N, R>(bytes_reader)?;
191
192        let is_huffman = flags & 0x1 == 0x1;
193
194        let string_data = bytes_reader
195            .get_bytes(string_len)
196            .ok_or(DecodingError::UnexpectedFin)?;
197
198        let string_data = if is_huffman {
199            let mut string_dec = Vec::with_capacity(string_len);
200
201            httlib_huffman::decode(
202                string_data,
203                &mut string_dec,
204                httlib_huffman::DecoderSpeed::OneBit,
205            )
206            .map_err(|_| DecodingError::InvalidString)?;
207
208            string_dec
209        } else {
210            string_data.to_vec()
211        };
212
213        String::from_utf8(string_data).map_err(|_| DecodingError::InvalidString)
214    }
215}
216
217/// QPACK encoder.
218///
219/// It only supports stateless decoding, so all encoding
220/// will be performed by means of the static table.
221pub struct Encoder;
222
223impl Encoder {
224    /// Encodes headers into data to be transmitted.
225    pub fn encode<H, K, V>(headers: H) -> Box<[u8]>
226    where
227        H: IntoIterator<Item = (K, V)>,
228        K: AsRef<str>,
229        V: AsRef<str>,
230    {
231        let mut buffer = Vec::new();
232
233        Self::encode_integer::<8, _>(0, 0, &mut buffer).expect("vec does not eof");
234        Self::encode_integer::<7, _>(0, 0, &mut buffer).expect("vec does not eof");
235
236        for (key, value) in headers.into_iter() {
237            match StaticTable::lookup_index(key.as_ref(), value.as_ref()) {
238                Some(LookupIndexFound::KeyValue(index)) => {
239                    Self::encode_integer::<6, _>(0b11, index, &mut buffer)
240                        .expect("vec does not eof");
241                }
242                Some(LookupIndexFound::KeyOnly(index)) => {
243                    Self::encode_integer::<4, _>(0b0101, index, &mut buffer)
244                        .expect("vec does not eof");
245                    Self::encode_string::<7, _, _>(0, value, &mut buffer)
246                        .expect("vec does not eof");
247                }
248                None => {
249                    Self::encode_string::<3, _, _>(0b10, key, &mut buffer)
250                        .expect("vec does not eof");
251                    Self::encode_string::<7, _, _>(0, value, &mut buffer)
252                        .expect("vec does not eof");
253                }
254            }
255        }
256
257        buffer.into_boxed_slice()
258    }
259
260    fn encode_integer<const N: usize, W>(
261        flags: u8,
262        value: usize,
263        bytes_writer: &mut W,
264    ) -> Result<(), EndOfBuffer>
265    where
266        W: BytesWriter,
267    {
268        const_assert!(N: usize => N <= 8 && N >= 1);
269
270        let mask = (0x01 << N) - 1;
271        let flags = ((flags as usize) << N) as u8;
272
273        if value < mask {
274            bytes_writer.put_bytes(&[flags | value as u8])?;
275            return Ok(());
276        }
277
278        bytes_writer.put_bytes(&[flags | mask as u8])?;
279
280        let mut rem = value - mask;
281        while rem >= 0x80 {
282            let byte = rem as u8 | 0x80;
283            bytes_writer.put_bytes(&[byte])?;
284            rem >>= 7;
285        }
286
287        bytes_writer.put_bytes(&[rem as u8])?;
288
289        Ok(())
290    }
291
292    fn encode_string<const N: usize, S, W>(
293        flags: u8,
294        value: S,
295        bytes_writer: &mut W,
296    ) -> Result<(), EndOfBuffer>
297    where
298        S: AsRef<str>,
299        W: BytesWriter,
300    {
301        let value = value.as_ref().as_bytes();
302
303        let mut huffman_buffer = Vec::new();
304
305        let (is_huffman, string_data) = match httlib_huffman::encode(value, &mut huffman_buffer) {
306            Ok(()) => {
307                if huffman_buffer.len() < value.len() {
308                    (true, huffman_buffer.as_slice())
309                } else {
310                    (false, value)
311                }
312            }
313            Err(_) => (false, value),
314        };
315
316        let flags = (flags << 1) | (is_huffman as u8);
317
318        Self::encode_integer::<N, _>(flags, string_data.len(), bytes_writer)?;
319        bytes_writer.put_bytes(string_data)
320    }
321}
322
/// Outcome of a successful static-table search for a header.
enum LookupIndexFound {
    /// Index of an entry matching both the header name and its value.
    KeyValue(usize),
    /// Index of an entry matching the header name only.
    KeyOnly(usize),
}
327
328struct StaticTable;
329
330impl StaticTable {
331    const STATIC_TABLE: &'static [(&'static str, &'static str); 99] = &[
332        (":authority", ""),
333        (":path", "/"),
334        ("age", "0"),
335        ("content-disposition", ""),
336        ("content-length", "0"),
337        ("cookie", ""),
338        ("date", ""),
339        ("etag", ""),
340        ("if-modified-since", ""),
341        ("if-none-match", ""),
342        ("last-modified", ""),
343        ("link", ""),
344        ("location", ""),
345        ("referer", ""),
346        ("set-cookie", ""),
347        (":method", "CONNECT"),
348        (":method", "DELETE"),
349        (":method", "GET"),
350        (":method", "HEAD"),
351        (":method", "OPTIONS"),
352        (":method", "POST"),
353        (":method", "PUT"),
354        (":scheme", "http"),
355        (":scheme", "https"),
356        (":status", "103"),
357        (":status", "200"),
358        (":status", "304"),
359        (":status", "404"),
360        (":status", "503"),
361        ("accept", "*/*"),
362        ("accept", "application/dns-message"),
363        ("accept-encoding", "gzip, deflate, br"),
364        ("accept-ranges", "bytes"),
365        ("access-control-allow-headers", "cache-control"),
366        ("access-control-allow-headers", "content-type"),
367        ("access-control-allow-origin", "*"),
368        ("cache-control", "max-age=0"),
369        ("cache-control", "max-age=2592000"),
370        ("cache-control", "max-age=604800"),
371        ("cache-control", "no-cache"),
372        ("cache-control", "no-store"),
373        ("cache-control", "public, max-age=31536000"),
374        ("content-encoding", "br"),
375        ("content-encoding", "gzip"),
376        ("content-type", "application/dns-message"),
377        ("content-type", "application/javascript"),
378        ("content-type", "application/json"),
379        ("content-type", "application/x-www-form-urlencoded"),
380        ("content-type", "image/gif"),
381        ("content-type", "image/jpeg"),
382        ("content-type", "image/png"),
383        ("content-type", "text/css"),
384        ("content-type", "text/html; charset=utf-8"),
385        ("content-type", "text/plain"),
386        ("content-type", "text/plain;charset=utf-8"),
387        ("range", "bytes=0-"),
388        ("strict-transport-security", "max-age=31536000"),
389        (
390            "strict-transport-security",
391            "max-age=31536000; includesubdomains",
392        ),
393        (
394            "strict-transport-security",
395            "max-age=31536000; includesubdomains; preload",
396        ),
397        ("vary", "accept-encoding"),
398        ("vary", "origin"),
399        ("x-content-type-options", "nosniff"),
400        ("x-xss-protection", "1; mode=block"),
401        (":status", "100"),
402        (":status", "204"),
403        (":status", "206"),
404        (":status", "302"),
405        (":status", "400"),
406        (":status", "403"),
407        (":status", "421"),
408        (":status", "425"),
409        (":status", "500"),
410        ("accept-language", ""),
411        ("access-control-allow-credentials", "FALSE"),
412        ("access-control-allow-credentials", "TRUE"),
413        ("access-control-allow-headers", "*"),
414        ("access-control-allow-methods", "get"),
415        ("access-control-allow-methods", "get, post, options"),
416        ("access-control-allow-methods", "options"),
417        ("access-control-expose-headers", "content-length"),
418        ("access-control-request-headers", "content-type"),
419        ("access-control-request-method", "get"),
420        ("access-control-request-method", "post"),
421        ("alt-svc", "clear"),
422        ("authorization", ""),
423        (
424            "content-security-policy",
425            "script-src 'none'; object-src 'none'; base-uri 'none'",
426        ),
427        ("early-data", "1"),
428        ("expect-ct", ""),
429        ("forwarded", ""),
430        ("if-range", ""),
431        ("origin", ""),
432        ("purpose", "prefetch"),
433        ("server", ""),
434        ("timing-allow-origin", "*"),
435        ("upgrade-insecure-requests", "1"),
436        ("user-agent", ""),
437        ("x-forwarded-for", ""),
438        ("x-frame-options", "deny"),
439        ("x-frame-options", "sameorigin"),
440    ];
441
442    fn lookup_field(index: usize) -> Option<(&'static str, &'static str)> {
443        Self::STATIC_TABLE.get(index).cloned()
444    }
445
446    fn lookup_index<K, V>(key: K, value: V) -> Option<LookupIndexFound>
447    where
448        K: AsRef<str>,
449        V: AsRef<str>,
450    {
451        Self::STATIC_TABLE
452            .iter()
453            .enumerate()
454            .find(|(_index, entry)| key.as_ref() == entry.0)
455            .map(|(index, entry)| {
456                if value.as_ref() == entry.1 {
457                    LookupIndexFound::KeyValue(index)
458                } else {
459                    LookupIndexFound::KeyOnly(index)
460                }
461            })
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use rand::random;
    use rand::rng;
    use rand::Rng;

    /// Every possible first byte must be classified without panicking.
    #[test]
    fn decode_field_line_type() {
        (0..=u8::MAX).for_each(|first_byte| {
            Decoder::decode_field_line_type(first_byte);
        });
    }

    /// Headers survive an encode/decode round trip, whichever representation
    /// the encoder picks for them.
    #[test]
    fn encode_decode() {
        let expected = HashMap::from([
            ("key1", "value1"),
            (":status", "200"),
            ("key2", "value2"),
            (":status", "not_found"),
        ]);

        let encoded = Encoder::encode(&expected);
        let decoded = Decoder::decode(encoded).unwrap();
        let decoded = decoded
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
            .collect();

        assert_eq!(expected, decoded);
    }

    /// Random flags and values round-trip through the integer codec.
    #[test]
    fn integer() {
        const PREFIX_LEN: usize = 5;

        let mut encoded = Vec::new();

        for _ in 0..1_000_000 {
            encoded.clear();

            let flags_in = random::<u8>() & ((0x1 << (8 - PREFIX_LEN)) - 1);
            let value_in = random::<u64>() as usize;

            Encoder::encode_integer::<PREFIX_LEN, _>(flags_in, value_in, &mut encoded).unwrap();

            let (flags_out, value_out) =
                Decoder::decode_integer::<PREFIX_LEN, _>(&mut encoded.as_slice()).unwrap();

            assert_eq!(flags_in, flags_out);
            assert_eq!(value_in, value_out);
        }
    }

    /// The largest representable value round-trips with the smallest prefix.
    #[test]
    fn integer_max() {
        let mut encoded = Vec::new();
        Encoder::encode_integer::<1, _>(0, usize::MAX, &mut encoded).unwrap();

        let (_, decoded) = Decoder::decode_integer::<1, _>(&mut encoded.as_slice()).unwrap();
        assert_eq!(decoded, usize::MAX);
    }

    /// A long enough run of saturated bytes must end in an overflow error.
    #[test]
    fn integer_overflow() {
        let mut input = Vec::new();
        let mut len = 0;

        loop {
            input.clear();
            input.resize(len, 0xFF);

            let outcome = Decoder::decode_integer::<1, _>(&mut input.as_slice());
            if matches!(outcome, Err(DecodingError::IntegerOverflow)) {
                break;
            }

            len += 1;
        }
    }

    /// Truncated integers are reported as premature end of input.
    #[test]
    fn integer_eof() {
        let saturated_prefix_only = [0b0000_0001];
        assert!(matches!(
            Decoder::decode_integer::<1, _>(&mut saturated_prefix_only.as_slice()),
            Err(DecodingError::UnexpectedFin)
        ));

        let dangling_continuation = [0b0000_0001, 0b1000_0000];
        assert!(matches!(
            Decoder::decode_integer::<1, _>(&mut dangling_continuation.as_slice()),
            Err(DecodingError::UnexpectedFin)
        ));
    }

    /// Random alphanumeric strings round-trip through the string codec.
    #[test]
    fn string() {
        const PREFIX_LEN: usize = 5;

        let mut encoded = Vec::new();

        for _ in 0..10_000 {
            encoded.clear();

            let flags = random::<u8>() & ((0x1 << (8 - PREFIX_LEN)) - 1);

            let len = rng().random_range(0..1024);
            let original = rng()
                .sample_iter(rand::distr::Alphanumeric)
                .take(len)
                .map(char::from)
                .collect::<String>();

            Encoder::encode_string::<PREFIX_LEN, _, _>(flags, &original, &mut encoded).unwrap();

            let decoded =
                Decoder::decode_string::<PREFIX_LEN, _>(&mut encoded.as_slice()).unwrap();

            assert_eq!(original, decoded);
        }
    }
}
583}