1use crate::headers::Headers;
2use crate::ids::InvalidStatusCode;
3use crate::ids::StatusCode;
4use url::Url;
5
6#[derive(Debug, thiserror::Error)]
8pub enum UrlParseError {
9 #[error("host is missing in the URL")]
11 EmptyHost,
12
13 #[error("invalid international domain name")]
15 IdnaError,
16
17 #[error("invalid port number")]
19 InvalidPort,
20
21 #[error("invalid IPv4 address")]
23 InvalidIpv4Address,
24
25 #[error("invalid IPv6 address")]
27 InvalidIpv6Address,
28
29 #[error("invalid domain character")]
31 InvalidDomainCharacter,
32
33 #[error("relative URL without a base")]
35 RelativeUrlWithoutBase,
36
37 #[error("relative URL with a cannot-be-a-base base")]
39 RelativeUrlWithCannotBeABaseBase,
40
41 #[error("a cannot-be-a-base URL doesn’t have a host to set")]
43 SetHostOnCannotBeABaseUrl,
44
45 #[error("URLs more than 4 GB are not supported")]
47 Overflow,
48
49 #[error("unknown error during URL parsing")]
51 Unknown,
52
53 #[error("URL scheme is not 'https'")]
55 SchemeNotHttps,
56}
57
58#[derive(Debug, thiserror::Error)]
60pub enum HeadersParseError {
61 #[error("':method' field is missing")]
63 MissingMethod,
64
65 #[error("':method' is not CONNECT")]
67 MethodNotConnect,
68
69 #[error("':scheme' field is missing")]
71 MissingScheme,
72
73 #[error("':scheme' is not 'https'")]
75 SchemeNotHttps,
76
77 #[error("':protocol' field is missing")]
79 MissingProtocol,
80
81 #[error("':protocol' is not 'webtransport'")]
83 ProtocolNotWebTransport,
84
85 #[error("':authority' field is missing")]
87 MissingAuthority,
88
89 #[error("':path' field is missing")]
91 MissingPath,
92
93 #[error("':status' field is missing")]
95 MissingStatusCode,
96
97 #[error("invalid HTTP status code")]
99 InvalidStatusCode,
100}
101
102#[derive(Debug, thiserror::Error)]
108#[error("used reserved header")]
109pub struct ReservedHeader;
110
111#[derive(Debug)]
113pub struct SessionRequest(Headers);
114
115impl SessionRequest {
116 pub const RESERVED_HEADERS: &'static [&'static str] =
128 &[":method", ":scheme", ":protocol", ":authority", ":path"];
129
130 pub fn new<S>(url: S) -> Result<Self, UrlParseError>
132 where
133 S: AsRef<str>,
134 {
135 let url = Url::parse(url.as_ref()).map_err(UrlParseError::from_url_parse_error)?;
136
137 if url.scheme() != "https" {
138 return Err(UrlParseError::SchemeNotHttps);
139 }
140
141 let path = format!(
142 "{}{}",
143 url.path(),
144 url.query().map(|s| format!("?{}", s)).unwrap_or_default()
145 );
146
147 let headers = [
148 (":method", "CONNECT"),
149 (":scheme", "https"),
150 (":protocol", "webtransport"),
151 (":authority", url.authority()),
152 (":path", &path),
153 ]
154 .into_iter()
155 .collect();
156
157 Ok(Self(headers))
158 }
159
160 pub fn authority(&self) -> &str {
162 self.0
163 .get(":authority")
164 .expect("Session request must contain ':authority' field")
165 }
166
167 pub fn path(&self) -> &str {
169 self.0
170 .get(":path")
171 .expect("Session request must contain ':path' field")
172 }
173
174 pub fn origin(&self) -> Option<&str> {
176 self.0.get("origin")
177 }
178
179 pub fn user_agent(&self) -> Option<&str> {
181 self.0.get("user-agent")
182 }
183
184 pub fn get<K>(&self, key: K) -> Option<&str>
186 where
187 K: AsRef<str>,
188 {
189 self.0.get(key)
190 }
191
192 pub fn insert<K, V>(&mut self, key: K, value: V) -> Result<(), ReservedHeader>
202 where
203 K: ToString,
204 V: ToString,
205 {
206 let key = key.to_string();
207
208 if Self::RESERVED_HEADERS.iter().any(|rh| rh == &key) {
209 return Err(ReservedHeader);
210 }
211
212 self.0.insert(key, value);
213 Ok(())
214 }
215
216 pub fn headers(&self) -> &Headers {
218 &self.0
219 }
220}
221
222impl TryFrom<Headers> for SessionRequest {
223 type Error = HeadersParseError;
224
225 fn try_from(headers: Headers) -> Result<Self, Self::Error> {
226 if headers
227 .get(":method")
228 .ok_or(HeadersParseError::MissingMethod)?
229 != "CONNECT"
230 {
231 return Err(HeadersParseError::MethodNotConnect);
232 }
233
234 if headers
235 .get(":scheme")
236 .ok_or(HeadersParseError::MissingScheme)?
237 != "https"
238 {
239 return Err(HeadersParseError::SchemeNotHttps);
240 }
241
242 if headers
243 .get(":protocol")
244 .ok_or(HeadersParseError::MissingProtocol)?
245 != "webtransport"
246 {
247 return Err(HeadersParseError::ProtocolNotWebTransport);
248 }
249
250 headers
251 .get(":authority")
252 .ok_or(HeadersParseError::MissingAuthority)?;
253
254 headers.get(":path").ok_or(HeadersParseError::MissingPath)?;
255
256 Ok(Self(headers))
257 }
258}
259
260impl UrlParseError {
261 fn from_url_parse_error(error: url::ParseError) -> Self {
262 match error {
263 url::ParseError::EmptyHost => UrlParseError::EmptyHost,
264 url::ParseError::IdnaError => UrlParseError::IdnaError,
265 url::ParseError::InvalidPort => UrlParseError::InvalidPort,
266 url::ParseError::InvalidIpv4Address => UrlParseError::InvalidIpv4Address,
267 url::ParseError::InvalidIpv6Address => UrlParseError::InvalidIpv6Address,
268 url::ParseError::InvalidDomainCharacter => UrlParseError::InvalidDomainCharacter,
269 url::ParseError::RelativeUrlWithoutBase => UrlParseError::RelativeUrlWithoutBase,
270 url::ParseError::RelativeUrlWithCannotBeABaseBase => {
271 UrlParseError::RelativeUrlWithCannotBeABaseBase
272 }
273 url::ParseError::SetHostOnCannotBeABaseUrl => UrlParseError::SetHostOnCannotBeABaseUrl,
274 url::ParseError::Overflow => UrlParseError::Overflow,
275 _ => UrlParseError::Unknown,
276 }
277 }
278}
279
280pub struct SessionResponse(Headers);
282
283impl SessionResponse {
284 pub fn with_status_code(status_code: StatusCode) -> Self {
286 let headers = [(":status", status_code.to_string())].into_iter().collect();
287 Self(headers)
288 }
289
290 pub fn ok() -> Self {
292 Self::with_status_code(StatusCode::OK)
293 }
294
295 pub fn forbidden() -> Self {
297 Self::with_status_code(StatusCode::FORBIDDEN)
298 }
299
300 pub fn not_found() -> Self {
302 Self::with_status_code(StatusCode::NOT_FOUND)
303 }
304
305 pub fn too_many_requests() -> Self {
307 Self::with_status_code(StatusCode::TOO_MANY_REQUESTS)
308 }
309
310 pub fn code(&self) -> StatusCode {
312 self.0
313 .get(":status")
314 .expect("Status code is always present")
315 .parse()
316 .expect("Status code value must be valid")
317 }
318
319 pub fn add<K, V>(&mut self, key: K, value: V)
323 where
324 K: ToString,
325 V: ToString,
326 {
327 self.0.insert(key, value);
328 }
329
330 pub fn headers(&self) -> &Headers {
332 &self.0
333 }
334}
335
336impl TryFrom<Headers> for SessionResponse {
337 type Error = HeadersParseError;
338
339 fn try_from(headers: Headers) -> Result<Self, Self::Error> {
340 let status_code = headers
341 .get(":status")
342 .ok_or(HeadersParseError::MissingStatusCode)?
343 .parse()
344 .map_err(|InvalidStatusCode| HeadersParseError::InvalidStatusCode)?;
345
346 Ok(Self::with_status_code(status_code))
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Header fields of a well-formed WebTransport session request.
    const VALID_REQUEST: [(&str, &str); 5] = [
        (":method", "CONNECT"),
        (":scheme", "https"),
        (":protocol", "webtransport"),
        (":authority", "localhost:4433"),
        (":path", "/"),
    ];

    /// `VALID_REQUEST` with `field` removed.
    fn request_without(field: &str) -> Headers {
        VALID_REQUEST
            .into_iter()
            .filter(|&(name, _)| name != field)
            .collect()
    }

    /// `VALID_REQUEST` with the value of `field` replaced by `value`.
    fn request_with(field: &str, value: &str) -> Headers {
        VALID_REQUEST
            .into_iter()
            .map(|(name, current)| (name, if name == field { value } else { current }))
            .collect()
    }

    #[test]
    fn parse_url() {
        let request = SessionRequest::new("https://localhost:4433/foo/bar?p1=1&p2=2").unwrap();
        assert_eq!(request.authority(), "localhost:4433");
        assert_eq!(request.path(), "/foo/bar?p1=1&p2=2");
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
        assert_eq!(request.get(":protocol").unwrap(), "webtransport");
    }

    #[test]
    fn not_https() {
        let error = SessionRequest::new("http://localhost:4433");
        assert!(matches!(error, Err(UrlParseError::SchemeNotHttps)));
    }

    #[test]
    fn parse_headers() {
        assert!(SessionRequest::try_from(VALID_REQUEST.into_iter().collect::<Headers>()).is_ok());
    }

    #[test]
    fn parse_headers_error_method() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":method")),
            Err(HeadersParseError::MissingMethod),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":method", "GET")),
            Err(HeadersParseError::MethodNotConnect),
        ));
    }

    #[test]
    fn parse_headers_error_scheme() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":scheme")),
            Err(HeadersParseError::MissingScheme),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":scheme", "http")),
            Err(HeadersParseError::SchemeNotHttps),
        ));
    }

    #[test]
    fn parse_headers_error_protocol() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":protocol")),
            Err(HeadersParseError::MissingProtocol),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":protocol", "websocket")),
            Err(HeadersParseError::ProtocolNotWebTransport),
        ));
    }

    #[test]
    fn parse_headers_error_authority_and_path() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":authority")),
            Err(HeadersParseError::MissingAuthority),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_without(":path")),
            Err(HeadersParseError::MissingPath),
        ));
    }

    #[test]
    fn insert() {
        let mut request = SessionRequest::new("https://example.com").unwrap();
        request.insert("version", "test").unwrap();
        assert_eq!(request.get("version").unwrap(), "test");
    }

    #[test]
    fn insert_reserved() {
        let mut request = SessionRequest::new("https://example.com").unwrap();

        for &field in SessionRequest::RESERVED_HEADERS {
            assert!(
                matches!(request.insert(field, "value"), Err(ReservedHeader)),
                "reserved header {} was accepted",
                field
            );
        }

        // Rejected inserts must leave the original values untouched.
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
        assert_eq!(request.authority(), "example.com");
        assert_eq!(request.path(), "/");
    }

    #[test]
    fn response_status_code() {
        assert_eq!(
            SessionResponse::not_found().code().to_string(),
            StatusCode::NOT_FOUND.to_string()
        );
    }

    #[test]
    fn parse_response_headers() {
        let response =
            SessionResponse::try_from([(":status", "200")].into_iter().collect::<Headers>())
                .unwrap();
        assert_eq!(response.code().to_string(), StatusCode::OK.to_string());
    }

    #[test]
    fn parse_response_headers_error() {
        assert!(matches!(
            SessionResponse::try_from([("server", "test")].into_iter().collect::<Headers>()),
            Err(HeadersParseError::MissingStatusCode),
        ));

        assert!(matches!(
            SessionResponse::try_from([(":status", "abc")].into_iter().collect::<Headers>()),
            Err(HeadersParseError::InvalidStatusCode),
        ));
    }
}