wtransport_proto/
session.rs

1use crate::headers::Headers;
2use crate::ids::InvalidStatusCode;
3use crate::ids::StatusCode;
4use url::Url;
5
/// Error returned when a URL cannot be turned into a [`SessionRequest`].
///
/// Most variants mirror the variant of [`url::ParseError`] with the same name;
/// [`SchemeNotHttps`](Self::SchemeNotHttps) is produced by
/// [`SessionRequest::new`] when the URL parses but uses another scheme.
#[derive(Debug, thiserror::Error)]
pub enum UrlParseError {
    /// The host part is missing from the URL.
    #[error("host is missing in the URL")]
    EmptyHost,

    /// The host is an invalid international domain name.
    #[error("invalid international domain name")]
    IdnaError,

    /// The port number is invalid.
    #[error("invalid port number")]
    InvalidPort,

    /// The host is an invalid IPv4 address.
    #[error("invalid IPv4 address")]
    InvalidIpv4Address,

    /// The host is an invalid IPv6 address.
    #[error("invalid IPv6 address")]
    InvalidIpv6Address,

    /// The domain contains an invalid character.
    #[error("invalid domain character")]
    InvalidDomainCharacter,

    /// The URL is relative and no base was provided.
    #[error("relative URL without a base")]
    RelativeUrlWithoutBase,

    /// The URL is relative and the base is a cannot-be-a-base URL.
    #[error("relative URL with a cannot-be-a-base base")]
    RelativeUrlWithCannotBeABaseBase,

    /// A cannot-be-a-base URL does not have a host to set.
    #[error("a cannot-be-a-base URL doesn’t have a host to set")]
    SetHostOnCannotBeABaseUrl,

    /// URLs larger than 4 GB are not supported.
    #[error("URLs more than 4 GB are not supported")]
    Overflow,

    /// Any other error reported by the URL parser.
    #[error("unknown error during URL parsing")]
    Unknown,

    /// WebTransport only supports the `https` URL scheme.
    #[error("URL scheme is not 'https'")]
    SchemeNotHttps,
}
57
/// Error returned when a set of [`Headers`] cannot be interpreted as a
/// [`SessionRequest`] or a [`SessionResponse`].
#[derive(Debug, thiserror::Error)]
pub enum HeadersParseError {
    /// The `:method` field is missing.
    #[error("':method' field is missing")]
    MissingMethod,

    /// The `:method` field is present but is not `CONNECT`.
    #[error("':method' is not CONNECT")]
    MethodNotConnect,

    /// The `:scheme` field is missing.
    #[error("':scheme' field is missing")]
    MissingScheme,

    /// The `:scheme` field is present but is not `https`.
    #[error("':scheme' is not 'https'")]
    SchemeNotHttps,

    /// The `:protocol` field is missing.
    #[error("':protocol' field is missing")]
    MissingProtocol,

    /// The `:protocol` field is present but is not `webtransport`.
    #[error("':protocol' is not 'webtransport'")]
    ProtocolNotWebTransport,

    /// The `:authority` field is missing.
    #[error("':authority' field is missing")]
    MissingAuthority,

    /// The `:path` field is missing.
    #[error("':path' field is missing")]
    MissingPath,

    /// The `:status` field is missing (responses only).
    #[error("':status' field is missing")]
    MissingStatusCode,

    /// The `:status` field is present but its value is not a valid status code.
    #[error("invalid HTTP status code")]
    InvalidStatusCode,
}
101
/// An error when attempting to insert a value for a reserved header.
///
/// It is returned by [`SessionRequest::insert`] when the key of the
/// key-value pair being inserted is one of the
/// [reserved headers](SessionRequest::RESERVED_HEADERS). The request is left
/// untouched in that case.
#[derive(Debug, thiserror::Error)]
#[error("used reserved header")]
pub struct ReservedHeader;
110
/// A CONNECT WebTransport request.
///
/// This is a thin wrapper around the request [`Headers`]. It is built either
/// from a URL with [`SessionRequest::new`] or from already-received headers
/// through `TryFrom<Headers>`; both paths ensure the `:authority` and `:path`
/// fields are present.
#[derive(Debug)]
pub struct SessionRequest(Headers);
114
115impl SessionRequest {
116    /// A collection of reserved headers used in the WebTransport protocol.
117    ///
118    /// Reserved headers have special significance in the WebTransport protocol and
119    /// cannot be used as additional headers with the [`insert`](Self::insert) method.
120    ///
121    /// The following headers are considered reserved:
122    /// - `:method`
123    /// - `:scheme`
124    /// - `:protocol`
125    /// - `:authority`
126    /// - `:path`
127    pub const RESERVED_HEADERS: &'static [&'static str] =
128        &[":method", ":scheme", ":protocol", ":authority", ":path"];
129
130    /// Parses an URL to build a Session request.
131    pub fn new<S>(url: S) -> Result<Self, UrlParseError>
132    where
133        S: AsRef<str>,
134    {
135        let url = Url::parse(url.as_ref()).map_err(UrlParseError::from_url_parse_error)?;
136
137        if url.scheme() != "https" {
138            return Err(UrlParseError::SchemeNotHttps);
139        }
140
141        let path = format!(
142            "{}{}",
143            url.path(),
144            url.query().map(|s| format!("?{}", s)).unwrap_or_default()
145        );
146
147        let headers = [
148            (":method", "CONNECT"),
149            (":scheme", "https"),
150            (":protocol", "webtransport"),
151            (":authority", url.authority()),
152            (":path", &path),
153        ]
154        .into_iter()
155        .collect();
156
157        Ok(Self(headers))
158    }
159
160    /// Returns the `:authority` field of the request.
161    pub fn authority(&self) -> &str {
162        self.0
163            .get(":authority")
164            .expect("Session request must contain ':authority' field")
165    }
166
167    /// Returns the `:path` field of the request.
168    pub fn path(&self) -> &str {
169        self.0
170            .get(":path")
171            .expect("Session request must contain ':path' field")
172    }
173
174    /// Returns the `origin` field of the request if present.
175    pub fn origin(&self) -> Option<&str> {
176        self.0.get("origin")
177    }
178
179    /// Returns the `user-agent` field of the request if present.
180    pub fn user_agent(&self) -> Option<&str> {
181        self.0.get("user-agent")
182    }
183
184    /// Gets a field from the request (if present).
185    pub fn get<K>(&self, key: K) -> Option<&str>
186    where
187        K: AsRef<str>,
188    {
189        self.0.get(key)
190    }
191
192    /// Inserts a key-value pair into the header map, checking for reserved headers.
193    ///
194    /// This method inserts a key-value pair into the header map after ensuring that
195    /// the specified key is not one of the [reserved headers](Self::RESERVED_HEADERS).
196    /// If the key is reserved, the method returns an `Err(ReservedHeader)` indicating
197    /// the attempt to insert a value for a reserved header.
198    ///
199    /// If the key already exists in the header map, the corresponding value is updated with
200    /// the new value.
201    pub fn insert<K, V>(&mut self, key: K, value: V) -> Result<(), ReservedHeader>
202    where
203        K: ToString,
204        V: ToString,
205    {
206        let key = key.to_string();
207
208        if Self::RESERVED_HEADERS.iter().any(|rh| rh == &key) {
209            return Err(ReservedHeader);
210        }
211
212        self.0.insert(key, value);
213        Ok(())
214    }
215
216    /// Returns the whole headers associated with the request.
217    pub fn headers(&self) -> &Headers {
218        &self.0
219    }
220}
221
222impl TryFrom<Headers> for SessionRequest {
223    type Error = HeadersParseError;
224
225    fn try_from(headers: Headers) -> Result<Self, Self::Error> {
226        if headers
227            .get(":method")
228            .ok_or(HeadersParseError::MissingMethod)?
229            != "CONNECT"
230        {
231            return Err(HeadersParseError::MethodNotConnect);
232        }
233
234        if headers
235            .get(":scheme")
236            .ok_or(HeadersParseError::MissingScheme)?
237            != "https"
238        {
239            return Err(HeadersParseError::SchemeNotHttps);
240        }
241
242        if headers
243            .get(":protocol")
244            .ok_or(HeadersParseError::MissingProtocol)?
245            != "webtransport"
246        {
247            return Err(HeadersParseError::ProtocolNotWebTransport);
248        }
249
250        headers
251            .get(":authority")
252            .ok_or(HeadersParseError::MissingAuthority)?;
253
254        headers.get(":path").ok_or(HeadersParseError::MissingPath)?;
255
256        Ok(Self(headers))
257    }
258}
259
260impl UrlParseError {
261    fn from_url_parse_error(error: url::ParseError) -> Self {
262        match error {
263            url::ParseError::EmptyHost => UrlParseError::EmptyHost,
264            url::ParseError::IdnaError => UrlParseError::IdnaError,
265            url::ParseError::InvalidPort => UrlParseError::InvalidPort,
266            url::ParseError::InvalidIpv4Address => UrlParseError::InvalidIpv4Address,
267            url::ParseError::InvalidIpv6Address => UrlParseError::InvalidIpv6Address,
268            url::ParseError::InvalidDomainCharacter => UrlParseError::InvalidDomainCharacter,
269            url::ParseError::RelativeUrlWithoutBase => UrlParseError::RelativeUrlWithoutBase,
270            url::ParseError::RelativeUrlWithCannotBeABaseBase => {
271                UrlParseError::RelativeUrlWithCannotBeABaseBase
272            }
273            url::ParseError::SetHostOnCannotBeABaseUrl => UrlParseError::SetHostOnCannotBeABaseUrl,
274            url::ParseError::Overflow => UrlParseError::Overflow,
275            _ => UrlParseError::Unknown,
276        }
277    }
278}
279
280/// A WebTransport CONNECT response.
281pub struct SessionResponse(Headers);
282
283impl SessionResponse {
284    /// Constructs from [`StatusCode`].
285    pub fn with_status_code(status_code: StatusCode) -> Self {
286        let headers = [(":status", status_code.to_string())].into_iter().collect();
287        Self(headers)
288    }
289
290    /// Constructs with [`StatusCode::OK`].
291    pub fn ok() -> Self {
292        Self::with_status_code(StatusCode::OK)
293    }
294
295    /// Constructs with [`StatusCode::FORBIDDEN`].
296    pub fn forbidden() -> Self {
297        Self::with_status_code(StatusCode::FORBIDDEN)
298    }
299
300    /// Constructs with [`StatusCode::NOT_FOUND`].
301    pub fn not_found() -> Self {
302        Self::with_status_code(StatusCode::NOT_FOUND)
303    }
304
305    /// Constructs with [`StatusCode::TOO_MANY_REQUESTS`].
306    pub fn too_many_requests() -> Self {
307        Self::with_status_code(StatusCode::TOO_MANY_REQUESTS)
308    }
309
310    /// Returns the status code.
311    pub fn code(&self) -> StatusCode {
312        self.0
313            .get(":status")
314            .expect("Status code is always present")
315            .parse()
316            .expect("Status code value must be valid")
317    }
318
319    /// Adds a header field to the response.
320    ///
321    /// If the key is already present, the value is updated.
322    pub fn add<K, V>(&mut self, key: K, value: V)
323    where
324        K: ToString,
325        V: ToString,
326    {
327        self.0.insert(key, value);
328    }
329
330    /// Returns the whole headers associated with the request.
331    pub fn headers(&self) -> &Headers {
332        &self.0
333    }
334}
335
336impl TryFrom<Headers> for SessionResponse {
337    type Error = HeadersParseError;
338
339    fn try_from(headers: Headers) -> Result<Self, Self::Error> {
340        let status_code = headers
341            .get(":status")
342            .ok_or(HeadersParseError::MissingStatusCode)?
343            .parse()
344            .map_err(|InvalidStatusCode| HeadersParseError::InvalidStatusCode)?;
345
346        Ok(Self::with_status_code(status_code))
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `fields` into [`Headers`] and tries to interpret them as a
    /// session request.
    fn request_from(fields: &[(&str, &str)]) -> Result<SessionRequest, HeadersParseError> {
        SessionRequest::try_from(fields.iter().copied().collect::<Headers>())
    }

    #[test]
    fn parse_url() {
        let request = SessionRequest::new("https://localhost:4433/foo/bar?p1=1&p2=2").unwrap();

        assert_eq!(request.authority(), "localhost:4433");
        assert_eq!(request.path(), "/foo/bar?p1=1&p2=2");
        assert_eq!(request.get(":method"), Some("CONNECT"));
        assert_eq!(request.get(":protocol"), Some("webtransport"));
    }

    #[test]
    fn not_https() {
        let result = SessionRequest::new("http://localhost:4433");
        assert!(matches!(result, Err(UrlParseError::SchemeNotHttps)));
    }

    #[test]
    fn parse_headers() {
        let result = request_from(&[
            (":method", "CONNECT"),
            (":scheme", "https"),
            (":protocol", "webtransport"),
            (":authority", "localhost:4433"),
            (":path", "/"),
        ]);

        assert!(result.is_ok());
    }

    #[test]
    fn parse_headers_error_method() {
        // `:method` absent.
        let missing = request_from(&[
            (":scheme", "https"),
            (":protocol", "webtransport"),
            (":authority", "localhost:4433"),
            (":path", "/"),
        ]);
        assert!(matches!(missing, Err(HeadersParseError::MissingMethod)));

        // `:method` present but not CONNECT.
        let wrong = request_from(&[
            (":method", "GET"),
            (":scheme", "https"),
            (":protocol", "webtransport"),
            (":authority", "localhost:4433"),
            (":path", "/"),
        ]);
        assert!(matches!(wrong, Err(HeadersParseError::MethodNotConnect)));
    }

    #[test]
    fn parse_headers_error_scheme() {
        // `:scheme` absent.
        let missing = request_from(&[
            (":method", "CONNECT"),
            (":protocol", "webtransport"),
            (":authority", "localhost:4433"),
            (":path", "/"),
        ]);
        assert!(matches!(missing, Err(HeadersParseError::MissingScheme)));

        // `:scheme` present but not https.
        let wrong = request_from(&[
            (":method", "CONNECT"),
            (":scheme", "http"),
            (":protocol", "webtransport"),
            (":authority", "localhost:4433"),
            (":path", "/"),
        ]);
        assert!(matches!(wrong, Err(HeadersParseError::SchemeNotHttps)));
    }

    #[test]
    fn insert() {
        let mut request = SessionRequest::new("https://example.com").unwrap();

        request.insert("version", "test").unwrap();

        assert_eq!(request.get("version"), Some("test"));
    }

    #[test]
    fn insert_reserved() {
        let mut request = SessionRequest::new("https://example.com").unwrap();

        // Every reserved pseudo-header must be rejected.
        let attempts = [
            (":method", "GET"),
            (":scheme", "ftp"),
            (":protocol", "web"),
            (":authority", "me"),
            (":path", "example"),
        ];

        for (key, value) in attempts {
            assert!(matches!(request.insert(key, value), Err(ReservedHeader)));
        }
    }
}