wtransport_proto/
ids.rs

1use crate::varint::VarInt;
2use std::fmt;
3use std::str::FromStr;
4
5/// Stream id.
6#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
7pub struct StreamId(VarInt);
8
9impl StreamId {
10    /// The largest stream id.
11    pub const MAX: StreamId = StreamId(VarInt::MAX);
12
13    /// New stream id.
14    #[inline(always)]
15    pub const fn new(varint: VarInt) -> Self {
16        Self(varint)
17    }
18
19    /// Checks whether a stream is bi-directional or not.
20    #[inline(always)]
21    pub const fn is_bidirectional(self) -> bool {
22        self.0.into_inner() & 0x2 == 0
23    }
24
25    /// Checks whether a stream is client-initiated or not.
26    #[inline(always)]
27    pub const fn is_client_initiated(self) -> bool {
28        self.0.into_inner() & 0x1 == 0
29    }
30
31    /// Checks whether a stream is locally initiated or not.
32    #[inline(always)]
33    pub const fn is_local(self, is_server: bool) -> bool {
34        (self.0.into_inner() & 0x1) == (is_server as u64)
35    }
36
37    /// Returns the integer value as `u64`.
38    #[inline(always)]
39    pub const fn into_u64(self) -> u64 {
40        self.0.into_inner()
41    }
42
43    /// Returns the stream id as [`VarInt`] value.
44    #[inline(always)]
45    pub const fn into_varint(self) -> VarInt {
46        self.0
47    }
48}
49
50impl From<StreamId> for VarInt {
51    #[inline(always)]
52    fn from(stream_id: StreamId) -> Self {
53        stream_id.0
54    }
55}
56
57impl fmt::Debug for StreamId {
58    #[inline(always)]
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        self.0.fmt(f)
61    }
62}
63
64impl fmt::Display for StreamId {
65    #[inline(always)]
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.0.fmt(f)
68    }
69}
70
/// Error for invalid Session ID value.
///
/// Returned by [`SessionId::try_from_session_stream`] when the given stream is
/// not both *bidirectional* and *client-initiated*.
#[derive(Debug, thiserror::Error)]
#[error("invalid session ID")]
pub struct InvalidSessionId;
75
76/// A WebTransport session id.
77///
78/// Internally, it corresponds to a *bidirectional* *client-initiated* QUIC stream,
79/// that is, a webtransport *session stream*.
80#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
81pub struct SessionId(StreamId);
82
83impl SessionId {
84    /// Returns the integer value as `u64`.
85    #[inline(always)]
86    pub const fn into_u64(self) -> u64 {
87        self.0.into_u64()
88    }
89
90    /// Returns the session id as [`VarInt`] value.
91    #[inline(always)]
92    pub const fn into_varint(self) -> VarInt {
93        self.0.into_varint()
94    }
95
96    /// Returns the corresponding session QUIC stream.
97    #[inline(always)]
98    pub const fn session_stream(self) -> StreamId {
99        self.0
100    }
101
102    /// Tries to create a session id from its session stream.
103    ///
104    /// `stream_id` must be *bidirectional* and *client-initiated*, otherwise
105    /// an [`Err`] is returned.
106    pub fn try_from_session_stream(stream_id: StreamId) -> Result<Self, InvalidSessionId> {
107        if stream_id.is_bidirectional() && stream_id.is_client_initiated() {
108            Ok(Self(stream_id))
109        } else {
110            Err(InvalidSessionId)
111        }
112    }
113
114    /// Creates a session id without checking session stream properties.
115    ///
116    /// # Safety
117    ///
118    /// `stream_id` must be *bidirectional* and *client-initiated*.
119    #[inline(always)]
120    pub const unsafe fn from_session_stream_unchecked(stream_id: StreamId) -> Self {
121        debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
122        Self(stream_id)
123    }
124
125    #[inline(always)]
126    pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidSessionId> {
127        Self::try_from_session_stream(StreamId::new(varint))
128    }
129
130    #[cfg(test)]
131    pub(crate) fn maybe_invalid(varint: VarInt) -> Self {
132        Self(StreamId::new(varint))
133    }
134}
135
136impl fmt::Debug for SessionId {
137    #[inline(always)]
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        self.0.fmt(f)
140    }
141}
142
143impl fmt::Display for SessionId {
144    #[inline(always)]
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        self.0.fmt(f)
147    }
148}
149
/// Error for invalid Quarter Stream ID value (too large).
///
/// Produced when a value exceeds [`QStreamId::MAX`].
#[derive(Debug, thiserror::Error)]
#[error("invalid QStream ID")]
pub struct InvalidQStreamId;
154
155/// HTTP3 Quarter Stream ID.
156#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
157pub struct QStreamId(VarInt);
158
159impl QStreamId {
160    /// The largest quarter stream id.
161    // SAFETY: value is less than max varint
162    pub const MAX: QStreamId =
163        unsafe { Self(VarInt::from_u64_unchecked(1_152_921_504_606_846_975)) };
164
165    /// Creates a quarter stream id from its corresponding [`SessionId`]
166    #[inline(always)]
167    pub const fn from_session_id(session_id: SessionId) -> Self {
168        let value = session_id.into_u64() >> 2;
169        debug_assert!(value <= Self::MAX.into_u64());
170
171        // SAFETY: after bitwise operation from stream id, result is surely a varint
172        let varint = unsafe { VarInt::from_u64_unchecked(value) };
173
174        Self(varint)
175    }
176
177    /// Returns its corresponding [`StreamId`].
178    ///
179    /// This is a *client-initiated* *bidirectional* stream.
180    #[inline(always)]
181    pub const fn into_stream_id(self) -> StreamId {
182        // SAFETY: Quarter Stream ID origin from a valid Stream ID
183        let varint = unsafe {
184            debug_assert!(self.0.into_inner() << 2 <= VarInt::MAX.into_inner());
185            VarInt::from_u64_unchecked(self.0.into_inner() << 2)
186        };
187
188        StreamId::new(varint)
189    }
190
191    /// Returns its corresponding [`SessionId`].
192    #[inline(always)]
193    pub const fn into_session_id(self) -> SessionId {
194        let stream_id = self.into_stream_id();
195
196        // SAFETY: corresponding stream for qstream is bidirectional and client-initiated
197        unsafe {
198            debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
199            SessionId::from_session_stream_unchecked(stream_id)
200        }
201    }
202
203    /// Returns the integer value as `u64`.
204    #[inline(always)]
205    pub const fn into_u64(self) -> u64 {
206        self.0.into_inner()
207    }
208
209    /// Returns the quarter stream id as [`VarInt`] value.
210    #[inline(always)]
211    pub const fn into_varint(self) -> VarInt {
212        self.0
213    }
214
215    pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidQStreamId> {
216        if varint <= Self::MAX.into_varint() {
217            Ok(Self(varint))
218        } else {
219            Err(InvalidQStreamId)
220        }
221    }
222
223    #[cfg(test)]
224    pub(crate) fn maybe_invalid(varint: VarInt) -> QStreamId {
225        Self(varint)
226    }
227}
228
229impl fmt::Debug for QStreamId {
230    #[inline(always)]
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        self.0.fmt(f)
233    }
234}
235
236impl fmt::Display for QStreamId {
237    #[inline(always)]
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        self.0.fmt(f)
240    }
241}
242
/// Error for invalid HTTP status code.
///
/// Used as the error type of the `TryFrom` and `FromStr` conversions into
/// [`StatusCode`].
#[derive(Debug, thiserror::Error)]
#[error("invalid HTTP status code")]
pub struct InvalidStatusCode;
247
248/// HTTP status code (rfc9110).
249#[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
250pub struct StatusCode(u16);
251
252impl StatusCode {
253    /// The largest code.
254    pub const MAX: Self = Self(599);
255
256    /// The smallest code.
257    pub const MIN: Self = Self(100);
258
259    /// HTTP 200 OK status code.
260    pub const OK: Self = Self(200);
261
262    /// HTTP 403 Forbidden status code.
263    pub const FORBIDDEN: Self = Self(403);
264
265    /// HTTP 404 Not Found status code.
266    pub const NOT_FOUND: Self = Self(404);
267
268    /// HTTP 429 Too Many Requests status code.
269    pub const TOO_MANY_REQUESTS: Self = Self(429);
270
271    /// Tries to construct from `u32`.
272    #[inline(always)]
273    pub fn try_from_u32(value: u32) -> Result<Self, InvalidStatusCode> {
274        value.try_into()
275    }
276
277    /// Extracts the integer value as `u16`.
278    #[inline(always)]
279    pub fn into_inner(self) -> u16 {
280        self.0
281    }
282
283    /// Returns true if the status code is 2xx.
284    #[inline(always)]
285    pub fn is_successful(self) -> bool {
286        (200..300).contains(&self.0)
287    }
288}
289
290impl TryFrom<u8> for StatusCode {
291    type Error = InvalidStatusCode;
292
293    fn try_from(value: u8) -> Result<Self, Self::Error> {
294        if u16::from(value) >= Self::MIN.0 && u16::from(value) <= Self::MAX.0 {
295            Ok(Self(u16::from(value)))
296        } else {
297            Err(InvalidStatusCode)
298        }
299    }
300}
301
302impl TryFrom<u16> for StatusCode {
303    type Error = InvalidStatusCode;
304
305    fn try_from(value: u16) -> Result<Self, Self::Error> {
306        if (Self::MIN.0..=Self::MAX.0).contains(&value) {
307            Ok(Self(value))
308        } else {
309            Err(InvalidStatusCode)
310        }
311    }
312}
313
314impl TryFrom<u32> for StatusCode {
315    type Error = InvalidStatusCode;
316
317    fn try_from(value: u32) -> Result<Self, Self::Error> {
318        if value >= u32::from(Self::MIN.0) && value <= u32::from(Self::MAX.0) {
319            Ok(Self(value as u16))
320        } else {
321            Err(InvalidStatusCode)
322        }
323    }
324}
325
326impl TryFrom<u64> for StatusCode {
327    type Error = InvalidStatusCode;
328
329    fn try_from(value: u64) -> Result<Self, Self::Error> {
330        if value >= u64::from(Self::MIN.0) && value <= u64::from(Self::MAX.0) {
331            Ok(Self(value as u16))
332        } else {
333            Err(InvalidStatusCode)
334        }
335    }
336}
337
338impl FromStr for StatusCode {
339    type Err = InvalidStatusCode;
340
341    fn from_str(s: &str) -> Result<Self, Self::Err> {
342        Ok(Self(s.parse().map_err(|_| InvalidStatusCode)?))
343    }
344}
345
346impl fmt::Debug for StatusCode {
347    #[inline]
348    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349        self.0.fmt(f)
350    }
351}
352
353impl fmt::Display for StatusCode {
354    #[inline]
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        self.0.fmt(f)
357    }
358}
359
#[cfg(test)]
mod tests {
    use self::utils::{stream_types, StreamType};
    use super::*;

    #[test]
    fn stream_properties() {
        for (id, stream_type) in stream_types(1024) {
            let stream_id = StreamId::new(id);

            // Expected (bidirectional, client-initiated) flags for each stream type.
            let (bidirectional, client_initiated) = match stream_type {
                StreamType::ClientBi => (true, true),
                StreamType::ServerBi => (true, false),
                StreamType::ClientUni => (false, true),
                StreamType::ServerUni => (false, false),
            };

            assert_eq!(stream_id.is_bidirectional(), bidirectional);
            assert_eq!(stream_id.is_client_initiated(), client_initiated);
            // A client owns client-initiated streams; a server owns the others.
            assert_eq!(stream_id.is_local(false), client_initiated);
            assert_eq!(stream_id.is_local(true), !client_initiated);
        }
    }

    #[test]
    fn session_id() {
        for (id, stream_type) in stream_types(1024) {
            let is_session_stream = matches!(stream_type, StreamType::ClientBi);

            assert_eq!(SessionId::try_from_varint(id).is_ok(), is_session_stream);
            assert_eq!(
                SessionId::try_from_session_stream(StreamId::new(id)).is_ok(),
                is_session_stream
            );
        }
    }

    #[test]
    fn qstream_id() {
        let session_streams = stream_types(1024).filter_map(|(id, stream_type)| {
            matches!(stream_type, StreamType::ClientBi).then_some(id)
        });

        for (quarter, id) in session_streams.enumerate() {
            let session_id = SessionId::try_from_varint(id).unwrap();
            let qstream_id = QStreamId::from_session_id(session_id);

            assert_eq!(qstream_id.into_stream_id(), session_id.session_stream());
            assert_eq!(qstream_id.into_session_id(), session_id);
            assert_eq!(qstream_id.into_u64(), quarter as u64);
        }
    }

    mod utils {
        use super::*;

        #[derive(Copy, Clone, Debug)]
        pub enum StreamType {
            ClientBi,
            ServerBi,
            ClientUni,
            ServerUni,
        }

        /// Stream type for consecutive ids: the two low bits of a stream id
        /// cycle through these four types in this order.
        const CYCLE: [StreamType; 4] = [
            StreamType::ClientBi,
            StreamType::ServerBi,
            StreamType::ClientUni,
            StreamType::ServerUni,
        ];

        /// Yields `(id, type)` for every stream id in `0..max_id`.
        pub fn stream_types(max_id: u32) -> impl Iterator<Item = (VarInt, StreamType)> {
            (0..max_id).map(|index| (VarInt::from_u32(index), CYCLE[(index % 4) as usize]))
        }
    }
}