1use crate::config::ClientConfig;
2use crate::config::DnsResolver;
3use crate::config::ServerConfig;
4use crate::connection::Connection;
5use crate::driver::streams::session::StreamSession;
6use crate::driver::streams::ProtoReadError;
7use crate::driver::streams::ProtoWriteError;
8use crate::driver::utils::varint_w2q;
9use crate::driver::Driver;
10use crate::error::ConnectingError;
11use crate::error::ConnectionError;
12use crate::VarInt;
13use quinn::TokioRuntime;
14use std::collections::HashMap;
15use std::future::Future;
16use std::future::IntoFuture;
17use std::marker::PhantomData;
18use std::net::SocketAddr;
19use std::net::SocketAddrV4;
20use std::net::SocketAddrV6;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::Context;
24use std::task::Poll;
25use tracing::debug;
26use url::Host;
27use url::Url;
28use wtransport_proto::error::ErrorCode;
29use wtransport_proto::frame::FrameKind;
30use wtransport_proto::headers::Headers;
31use wtransport_proto::session::ReservedHeader;
32use wtransport_proto::session::SessionRequest as SessionRequestProto;
33use wtransport_proto::session::SessionResponse as SessionResponseProto;
34
35pub mod endpoint_side {
37 use super::*;
38
39 pub struct Server {
43 pub(super) _marker: PhantomData<()>,
44 }
45
46 pub struct Client {
50 pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
51 }
52}
53
54pub struct Endpoint<Side> {
100 endpoint: quinn::Endpoint,
101 side: Side,
102}
103
104impl<Side> Endpoint<Side> {
105 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
107 self.endpoint.close(varint_w2q(error_code), reason);
108 }
109
110 pub async fn wait_idle(&self) {
112 self.endpoint.wait_idle().await;
113 }
114
115 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
117 self.endpoint.local_addr()
118 }
119
120 pub fn open_connections(&self) -> usize {
122 self.endpoint.open_connections()
123 }
124}
125
126impl Endpoint<endpoint_side::Server> {
127 pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
129 let endpoint_config = server_config.endpoint_config;
130 let quic_config = server_config.quic_config;
131 let socket = server_config.bind_address_config.bind_socket()?;
132 let runtime = Arc::new(TokioRuntime);
133
134 let endpoint = quinn::Endpoint::new(endpoint_config, Some(quic_config), socket, runtime)?;
135
136 Ok(Self {
137 endpoint,
138 side: endpoint_side::Server {
139 _marker: PhantomData,
140 },
141 })
142 }
143
144 pub async fn accept(&self) -> IncomingSession {
146 let quic_incoming = self
147 .endpoint
148 .accept()
149 .await
150 .expect("Endpoint cannot be closed");
151
152 debug!("New incoming QUIC connection");
153
154 IncomingSession(quic_incoming)
155 }
156
157 pub fn reload_config(&self, server_config: ServerConfig, rebind: bool) -> std::io::Result<()> {
168 if rebind {
169 let socket = server_config.bind_address_config.bind_socket()?;
170 self.endpoint.rebind(socket)?;
171 }
172
173 let quic_config = server_config.quic_config;
174 self.endpoint.set_server_config(Some(quic_config));
175
176 Ok(())
177 }
178}
179
180impl Endpoint<endpoint_side::Client> {
181 pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
183 let endpoint_config = client_config.endpoint_config;
184 let quic_config = client_config.quic_config;
185 let socket = client_config.bind_address_config.bind_socket()?;
186 let runtime = Arc::new(TokioRuntime);
187
188 let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket, runtime)?;
189
190 endpoint.set_default_client_config(quic_config);
191
192 Ok(Self {
193 endpoint,
194 side: endpoint_side::Client {
195 dns_resolver: client_config.dns_resolver,
196 },
197 })
198 }
199
200 pub async fn connect<O>(&self, options: O) -> Result<Connection, ConnectingError>
256 where
257 O: IntoConnectOptions,
258 {
259 let options = options.into_options();
260
261 let url = Url::parse(&options.url)
262 .map_err(|parse_error| ConnectingError::InvalidUrl(parse_error.to_string()))?;
263
264 if url.scheme() != "https" {
265 return Err(ConnectingError::InvalidUrl(
266 "WebTransport URL scheme must be 'https'".to_string(),
267 ));
268 }
269
270 let host = url.host().expect("https scheme must have an host");
271 let port = url.port().unwrap_or(443);
272
273 let (socket_address, server_name) = match host {
274 Host::Domain(domain) => {
275 let socket_address = self
276 .side
277 .dns_resolver
278 .resolve(&format!("{domain}:{port}"))
279 .await
280 .map_err(ConnectingError::DnsLookup)?
281 .ok_or(ConnectingError::DnsNotFound)?;
282
283 (socket_address, domain.to_string())
284 }
285 Host::Ipv4(address) => {
286 let socket_address = SocketAddr::V4(SocketAddrV4::new(address, port));
287 (socket_address, address.to_string())
288 }
289 Host::Ipv6(address) => {
290 let socket_address = SocketAddr::V6(SocketAddrV6::new(address, port, 0, 0));
291 (socket_address, address.to_string())
292 }
293 };
294
295 let quic_connection = self
296 .endpoint
297 .connect(socket_address, &server_name)
298 .map_err(ConnectingError::with_connect_error)?
299 .await
300 .map_err(|connection_error| {
301 ConnectingError::ConnectionError(connection_error.into())
302 })?;
303
304 let driver = Driver::init(quic_connection.clone());
305
306 let _settings = driver.accept_settings().await.map_err(|driver_error| {
307 ConnectingError::ConnectionError(ConnectionError::with_driver_error(
308 driver_error,
309 &quic_connection,
310 ))
311 })?;
312
313 let mut session_request_proto =
316 SessionRequestProto::new(url.as_ref()).expect("Url has been already validate");
317
318 for (k, v) in options.additional_headers {
319 session_request_proto
320 .insert(k.clone(), v)
321 .map_err(|ReservedHeader| ConnectingError::ReservedHeader(k))?;
322 }
323
324 let mut stream_session = match driver.open_session(session_request_proto).await {
325 Ok(stream_session) => stream_session,
326 Err(driver_error) => {
327 return Err(ConnectingError::ConnectionError(
328 ConnectionError::with_driver_error(driver_error, &quic_connection),
329 ))
330 }
331 };
332
333 let session_id = stream_session.session_id();
334
335 match stream_session
336 .write_frame(stream_session.request().headers().generate_frame())
337 .await
338 {
339 Ok(()) => {}
340 Err(ProtoWriteError::Stopped) => {
341 return Err(ConnectingError::SessionRejected);
342 }
343 Err(ProtoWriteError::NotConnected) => {
344 return Err(ConnectingError::with_no_connection(&quic_connection));
345 }
346 }
347
348 let frame = loop {
349 let frame = match stream_session.read_frame().await {
350 Ok(frame) => frame,
351 Err(ProtoReadError::H3(error_code)) => {
352 quic_connection.close(varint_w2q(error_code.to_code()), b"");
353 return Err(ConnectingError::ConnectionError(
354 ConnectionError::local_h3_error(error_code),
355 ));
356 }
357 Err(ProtoReadError::IO(_io_error)) => {
358 return Err(ConnectingError::with_no_connection(&quic_connection));
359 }
360 };
361
362 if let FrameKind::Exercise(_) = frame.kind() {
363 continue;
364 }
365 break frame;
366 };
367
368 if !matches!(frame.kind(), FrameKind::Headers) {
369 quic_connection.close(varint_w2q(ErrorCode::FrameUnexpected.to_code()), b"");
370 return Err(ConnectingError::ConnectionError(
371 ConnectionError::local_h3_error(ErrorCode::FrameUnexpected),
372 ));
373 }
374
375 let headers = match Headers::with_frame(&frame) {
376 Ok(headers) => headers,
377 Err(error_code) => {
378 quic_connection.close(varint_w2q(error_code.to_code()), b"");
379 return Err(ConnectingError::ConnectionError(
380 ConnectionError::local_h3_error(error_code),
381 ));
382 }
383 };
384
385 let session_response = match SessionResponseProto::try_from(headers) {
386 Ok(session_response) => session_response,
387 Err(_) => {
388 quic_connection.close(varint_w2q(ErrorCode::Message.to_code()), b"");
389 return Err(ConnectingError::ConnectionError(
390 ConnectionError::local_h3_error(ErrorCode::Message),
391 ));
392 }
393 };
394
395 if session_response.code().is_successful() {
396 match driver.register_session(stream_session).await {
397 Ok(()) => {}
398 Err(driver_error) => {
399 return Err(ConnectingError::ConnectionError(
400 ConnectionError::with_driver_error(driver_error, &quic_connection),
401 ))
402 }
403 }
404 } else {
405 return Err(ConnectingError::SessionRejected);
406 }
407
408 Ok(Connection::new(quic_connection, driver, session_id))
409 }
410}
411
/// Options for opening a WebTransport session with `Endpoint::connect`.
///
/// Holds the target URL and any extra headers to send with the session request.
/// Create one with [`ConnectOptions::builder`]. Any `ToString` value (e.g. `&str` or
/// `String`) also converts into options with no extra headers via [`IntoConnectOptions`].
#[derive(Debug, Clone)]
pub struct ConnectOptions {
    url: String,
    additional_headers: HashMap<String, String>,
}

impl ConnectOptions {
    /// Starts building options that target `url`.
    ///
    /// The URL is only stored here. It is parsed and validated later by `Endpoint::connect`.
    pub fn builder<S>(url: S) -> ConnectRequestBuilder
    where
        S: ToString,
    {
        ConnectRequestBuilder {
            url: url.to_string(),
            additional_headers: HashMap::new(),
        }
    }

    /// Returns the target URL exactly as it was supplied.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Returns the extra headers to send with the session request.
    pub fn additional_headers(&self) -> &HashMap<String, String> {
        &self.additional_headers
    }
}

/// Conversion into [`ConnectOptions`], accepted by `Endpoint::connect`.
pub trait IntoConnectOptions {
    /// Converts `self` into connect options.
    fn into_options(self) -> ConnectOptions;
}

/// Builder for [`ConnectOptions`], created by [`ConnectOptions::builder`].
#[derive(Debug, Clone)]
pub struct ConnectRequestBuilder {
    url: String,
    additional_headers: HashMap<String, String>,
}

impl ConnectRequestBuilder {
    /// Adds a header to send with the session request.
    ///
    /// If `key` was already added, its previous value is replaced. Keys are stored
    /// exactly as given. Reserved header names are not checked here: `Endpoint::connect`
    /// rejects them with `ConnectingError::ReservedHeader`.
    #[must_use]
    pub fn add_header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: ToString,
        V: ToString,
    {
        self.additional_headers
            .insert(key.to_string(), value.to_string());
        self
    }

    /// Finishes the builder, producing the [`ConnectOptions`].
    #[must_use]
    pub fn build(self) -> ConnectOptions {
        ConnectOptions {
            url: self.url,
            additional_headers: self.additional_headers,
        }
    }
}

impl IntoConnectOptions for ConnectRequestBuilder {
    fn into_options(self) -> ConnectOptions {
        self.build()
    }
}

impl IntoConnectOptions for ConnectOptions {
    fn into_options(self) -> ConnectOptions {
        self
    }
}

/// Any displayable value is treated as a URL with no additional headers.
impl<S> IntoConnectOptions for S
where
    S: ToString,
{
    fn into_options(self) -> ConnectOptions {
        ConnectOptions::builder(self).build()
    }
}
533
534type DynFutureIncomingSession =
535 dyn Future<Output = Result<SessionRequest, ConnectionError>> + Send + Sync;
536
537pub struct IncomingSession(quinn::Incoming);
541
542impl IncomingSession {
543 pub fn remote_address(&self) -> SocketAddr {
545 self.0.remote_address()
546 }
547
548 pub fn remote_address_validated(&self) -> bool {
553 self.0.remote_address_validated()
554 }
555
556 pub fn retry(self) {
562 self.0.retry().expect("remote address already verified");
563 }
564
565 pub fn refuse(self) {
567 self.0.refuse();
568 }
569
570 pub fn ignore(self) {
572 self.0.ignore();
573 }
574}
575
576impl IntoFuture for IncomingSession {
577 type IntoFuture = IncomingSessionFuture;
578 type Output = Result<SessionRequest, ConnectionError>;
579
580 fn into_future(self) -> Self::IntoFuture {
581 IncomingSessionFuture::new(self.0)
582 }
583}
584
585pub struct IncomingSessionFuture(Pin<Box<DynFutureIncomingSession>>);
589
590impl IncomingSessionFuture {
591 fn new(quic_incoming: quinn::Incoming) -> Self {
592 Self(Box::pin(Self::accept(quic_incoming)))
593 }
594
595 async fn accept(quic_incoming: quinn::Incoming) -> Result<SessionRequest, ConnectionError> {
596 let quic_connection = quic_incoming.await?;
597
598 let driver = Driver::init(quic_connection.clone());
599
600 let _settings = driver.accept_settings().await.map_err(|driver_error| {
601 ConnectionError::with_driver_error(driver_error, &quic_connection)
602 })?;
603
604 let stream_session = driver.accept_session().await.map_err(|driver_error| {
607 ConnectionError::with_driver_error(driver_error, &quic_connection)
608 })?;
609
610 Ok(SessionRequest::new(quic_connection, driver, stream_session))
611 }
612}
613
614impl Future for IncomingSessionFuture {
615 type Output = Result<SessionRequest, ConnectionError>;
616
617 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
618 Future::poll(self.0.as_mut(), cx)
619 }
620}
621
622pub struct SessionRequest {
627 quic_connection: quinn::Connection,
628 driver: Driver,
629 stream_session: StreamSession,
630}
631
632impl SessionRequest {
633 pub(crate) fn new(
634 quic_connection: quinn::Connection,
635 driver: Driver,
636 stream_session: StreamSession,
637 ) -> Self {
638 Self {
639 quic_connection,
640 driver,
641 stream_session,
642 }
643 }
644
645 #[inline(always)]
651 pub fn remote_address(&self) -> SocketAddr {
652 self.quic_connection.remote_address()
653 }
654
655 pub fn authority(&self) -> &str {
657 self.stream_session.request().authority()
658 }
659
660 pub fn path(&self) -> &str {
662 self.stream_session.request().path()
663 }
664
665 pub fn origin(&self) -> Option<&str> {
667 self.stream_session.request().origin()
668 }
669
670 pub fn user_agent(&self) -> Option<&str> {
672 self.stream_session.request().user_agent()
673 }
674
675 pub fn headers(&self) -> &HashMap<String, String> {
677 self.stream_session.request().headers().as_ref()
678 }
679
680 pub async fn accept(mut self) -> Result<Connection, ConnectionError> {
682 let response = SessionResponseProto::ok();
683
684 self.send_response(response).await?;
685
686 let session_id = self.stream_session.session_id();
687
688 self.driver
689 .register_session(self.stream_session)
690 .await
691 .map_err(|driver_error| {
692 ConnectionError::with_driver_error(driver_error, &self.quic_connection)
693 })?;
694
695 Ok(Connection::new(
696 self.quic_connection,
697 self.driver,
698 session_id,
699 ))
700 }
701
702 pub async fn forbidden(self) {
704 self.reject(SessionResponseProto::forbidden()).await;
705 }
706
707 pub async fn not_found(self) {
709 self.reject(SessionResponseProto::not_found()).await;
710 }
711
712 pub async fn too_many_requests(self) {
714 self.reject(SessionResponseProto::too_many_requests()).await;
715 }
716
717 async fn reject(mut self, response: SessionResponseProto) {
718 let _ = self.send_response(response).await;
719 self.stream_session.finish().await;
720 }
721
722 async fn send_response(
723 &mut self,
724 response: SessionResponseProto,
725 ) -> Result<(), ConnectionError> {
726 let frame = response.headers().generate_frame();
727
728 match self.stream_session.write_frame(frame).await {
729 Ok(()) => Ok(()),
730 Err(ProtoWriteError::NotConnected) => {
731 Err(ConnectionError::no_connect(&self.quic_connection))
732 }
733 Err(ProtoWriteError::Stopped) => {
734 self.quic_connection
735 .close(varint_w2q(ErrorCode::ClosedCriticalStream.to_code()), b"");
736
737 Err(ConnectionError::local_h3_error(
738 ErrorCode::ClosedCriticalStream,
739 ))
740 }
741 }
742 }
743}