wtransport/
endpoint.rs

1use crate::config::ClientConfig;
2use crate::config::DnsResolver;
3use crate::config::ServerConfig;
4use crate::connection::Connection;
5use crate::driver::streams::session::StreamSession;
6use crate::driver::streams::ProtoReadError;
7use crate::driver::streams::ProtoWriteError;
8use crate::driver::utils::varint_w2q;
9use crate::driver::Driver;
10use crate::error::ConnectingError;
11use crate::error::ConnectionError;
12use crate::VarInt;
13use quinn::TokioRuntime;
14use std::collections::HashMap;
15use std::future::Future;
16use std::future::IntoFuture;
17use std::marker::PhantomData;
18use std::net::SocketAddr;
19use std::net::SocketAddrV4;
20use std::net::SocketAddrV6;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::Context;
24use std::task::Poll;
25use tracing::debug;
26use url::Host;
27use url::Url;
28use wtransport_proto::error::ErrorCode;
29use wtransport_proto::frame::FrameKind;
30use wtransport_proto::headers::Headers;
31use wtransport_proto::session::ReservedHeader;
32use wtransport_proto::session::SessionRequest as SessionRequestProto;
33use wtransport_proto::session::SessionResponse as SessionResponseProto;
34
35/// Helper structure for Endpoint types.
36pub mod endpoint_side {
37    use super::*;
38
39    /// Type of endpoint accepting multiple WebTransport connections.
40    ///
41    /// Use [`Endpoint::server`] to create and server-endpoint.
42    pub struct Server {
43        pub(super) _marker: PhantomData<()>,
44    }
45
46    /// Type of endpoint opening a WebTransport connection.
47    ///
48    /// Use [`Endpoint::client`] to create and client-endpoint.
49    pub struct Client {
50        pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
51    }
52}
53
54/// Entrypoint for creating client or server connections.
55///
56/// A single endpoint can be used to accept or connect multiple connections.
57/// Each endpoint internally binds an UDP socket.
58///
59/// # Server
60/// Use [`Endpoint::server`] for creating a server-side endpoint.
61/// Afterwards use the method [`Endpoint::accept`] for awaiting on incoming session request.
62///
63/// ```no_run
64/// # use anyhow::Result;
65/// # use wtransport::ServerConfig;
66/// # use wtransport::Identity;
67/// use wtransport::Endpoint;
68///
69/// # async fn run() -> Result<()> {
70/// # let config = ServerConfig::builder()
71/// #       .with_bind_default(4433)
72/// #       .with_identity(Identity::self_signed(["doc"]).unwrap())
73/// #       .build();
74/// let server = Endpoint::server(config)?;
75/// loop {
76///     let incoming_session = server.accept().await;
77///     // Spawn task that handles client incoming session...
78/// }
79/// # Ok(())
80/// # }
81/// ```
82///
83/// # Client
84/// Use [`Endpoint::client`] for creating a client-side endpoint and use [`Endpoint::connect`]
85/// to connect to a server specifying the URL.
86///
87/// ```no_run
88/// # use anyhow::Result;
89/// use wtransport::ClientConfig;
90/// use wtransport::Endpoint;
91///
92/// # async fn run() -> Result<()> {
93/// let connection = Endpoint::client(ClientConfig::default())?
94///     .connect("https://localhost:4433")
95///     .await?;
96/// # Ok(())
97/// # }
98/// ```
99pub struct Endpoint<Side> {
100    endpoint: quinn::Endpoint,
101    side: Side,
102}
103
104impl<Side> Endpoint<Side> {
105    /// Closes all of this endpoint's connections immediately and cease accepting new connections.
106    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
107        self.endpoint.close(varint_w2q(error_code), reason);
108    }
109
110    /// Waits for all connections on the endpoint to be cleanly shut down.
111    pub async fn wait_idle(&self) {
112        self.endpoint.wait_idle().await;
113    }
114
115    /// Gets the local [`SocketAddr`] the underlying socket is bound to.
116    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
117        self.endpoint.local_addr()
118    }
119
120    /// Get the number of connections that are currently open.
121    pub fn open_connections(&self) -> usize {
122        self.endpoint.open_connections()
123    }
124}
125
126impl Endpoint<endpoint_side::Server> {
127    /// Constructs a *server* endpoint.
128    pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
129        let endpoint_config = server_config.endpoint_config;
130        let quic_config = server_config.quic_config;
131        let socket = server_config.bind_address_config.bind_socket()?;
132        let runtime = Arc::new(TokioRuntime);
133
134        let endpoint = quinn::Endpoint::new(endpoint_config, Some(quic_config), socket, runtime)?;
135
136        Ok(Self {
137            endpoint,
138            side: endpoint_side::Server {
139                _marker: PhantomData,
140            },
141        })
142    }
143
144    /// Get the next incoming connection attempt from a client.
145    pub async fn accept(&self) -> IncomingSession {
146        let quic_incoming = self
147            .endpoint
148            .accept()
149            .await
150            .expect("Endpoint cannot be closed");
151
152        debug!("New incoming QUIC connection");
153
154        IncomingSession(quic_incoming)
155    }
156
157    /// Reloads the server configuration.
158    ///
159    /// Useful for e.g. refreshing TLS certificates without disrupting existing connections.
160    ///
161    /// # Arguments
162    ///
163    /// * `server_config` - The new configuration for the server.
164    /// * `rebind` - A boolean indicating whether the server should rebind its socket.
165    ///              If `true`, the server will bind to a new socket with the provided configuration.
166    ///              If `false`, the bind address configuration will be ignored.
167    pub fn reload_config(&self, server_config: ServerConfig, rebind: bool) -> std::io::Result<()> {
168        if rebind {
169            let socket = server_config.bind_address_config.bind_socket()?;
170            self.endpoint.rebind(socket)?;
171        }
172
173        let quic_config = server_config.quic_config;
174        self.endpoint.set_server_config(Some(quic_config));
175
176        Ok(())
177    }
178}
179
180impl Endpoint<endpoint_side::Client> {
181    /// Constructs a *client* endpoint.
182    pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
183        let endpoint_config = client_config.endpoint_config;
184        let quic_config = client_config.quic_config;
185        let socket = client_config.bind_address_config.bind_socket()?;
186        let runtime = Arc::new(TokioRuntime);
187
188        let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket, runtime)?;
189
190        endpoint.set_default_client_config(quic_config);
191
192        Ok(Self {
193            endpoint,
194            side: endpoint_side::Client {
195                dns_resolver: client_config.dns_resolver,
196            },
197        })
198    }
199
200    /// Establishes a WebTransport connection to a specified URL.
201    ///
202    /// This method initiates a WebTransport connection to the specified URL.
203    /// It validates the URL, and performs necessary steps to establish a secure connection.
204    ///
205    /// # Arguments
206    ///
207    /// * `options` - Connection options specifying the URL and additional headers.
208    ///               It can be simply an [URL](https://en.wikipedia.org/wiki/URL) string representing
209    ///               the WebTransport endpoint to connect to. It must have an `https` scheme.
210    ///               The URL can specify either an IP address or a hostname.
211    ///               When specifying a hostname, the method will internally perform DNS resolution,
212    ///               configured with
213    ///               [`ClientConfigBuilder::dns_resolver`](crate::config::ClientConfigBuilder::dns_resolver).
214    ///
215    /// # Examples
216    ///
217    /// Connect using a URL with a hostname (DNS resolution is performed):
218    ///
219    /// ```no_run
220    /// # use anyhow::Result;
221    /// # use wtransport::endpoint::endpoint_side::Client;
222    /// # async fn example(endpoint: wtransport::Endpoint<Client>) -> Result<()> {
223    /// let url = "https://example.com:4433/webtransport";
224    /// let connection = endpoint.connect(url).await?;
225    /// # Ok(())
226    /// # }
227    /// ```
228    ///
229    /// Connect using a URL with an IP address:
230    ///
231    /// ```no_run
232    /// # use anyhow::Result;
233    /// # use wtransport::endpoint::endpoint_side::Client;
234    /// # async fn example(endpoint: wtransport::Endpoint<Client>) -> Result<()> {
235    /// let url = "https://127.0.0.1:4343/webtransport";
236    /// let connection = endpoint.connect(url).await?;
237    /// # Ok(())
238    /// # }
239    /// ```
240    ///
241    /// Connect adding an additional header:
242    ///
243    /// ```no_run
244    /// # use anyhow::Result;
245    /// # use wtransport::endpoint::endpoint_side::Client;
246    /// # use wtransport::endpoint::ConnectOptions;
247    /// # async fn example(endpoint: wtransport::Endpoint<Client>) -> Result<()> {
248    /// let options = ConnectOptions::builder("https://example.com:4433/webtransport")
249    ///     .add_header("Authorization", "AuthToken")
250    ///     .build();
251    /// let connection = endpoint.connect(options).await?;
252    /// # Ok(())
253    /// # }
254    /// ```
255    pub async fn connect<O>(&self, options: O) -> Result<Connection, ConnectingError>
256    where
257        O: IntoConnectOptions,
258    {
259        let options = options.into_options();
260
261        let url = Url::parse(&options.url)
262            .map_err(|parse_error| ConnectingError::InvalidUrl(parse_error.to_string()))?;
263
264        if url.scheme() != "https" {
265            return Err(ConnectingError::InvalidUrl(
266                "WebTransport URL scheme must be 'https'".to_string(),
267            ));
268        }
269
270        let host = url.host().expect("https scheme must have an host");
271        let port = url.port().unwrap_or(443);
272
273        let (socket_address, server_name) = match host {
274            Host::Domain(domain) => {
275                let socket_address = self
276                    .side
277                    .dns_resolver
278                    .resolve(&format!("{domain}:{port}"))
279                    .await
280                    .map_err(ConnectingError::DnsLookup)?
281                    .ok_or(ConnectingError::DnsNotFound)?;
282
283                (socket_address, domain.to_string())
284            }
285            Host::Ipv4(address) => {
286                let socket_address = SocketAddr::V4(SocketAddrV4::new(address, port));
287                (socket_address, address.to_string())
288            }
289            Host::Ipv6(address) => {
290                let socket_address = SocketAddr::V6(SocketAddrV6::new(address, port, 0, 0));
291                (socket_address, address.to_string())
292            }
293        };
294
295        let quic_connection = self
296            .endpoint
297            .connect(socket_address, &server_name)
298            .map_err(ConnectingError::with_connect_error)?
299            .await
300            .map_err(|connection_error| {
301                ConnectingError::ConnectionError(connection_error.into())
302            })?;
303
304        let driver = Driver::init(quic_connection.clone());
305
306        let _settings = driver.accept_settings().await.map_err(|driver_error| {
307            ConnectingError::ConnectionError(ConnectionError::with_driver_error(
308                driver_error,
309                &quic_connection,
310            ))
311        })?;
312
313        // TODO(biagio): validate settings
314
315        let mut session_request_proto =
316            SessionRequestProto::new(url.as_ref()).expect("Url has been already validate");
317
318        for (k, v) in options.additional_headers {
319            session_request_proto
320                .insert(k.clone(), v)
321                .map_err(|ReservedHeader| ConnectingError::ReservedHeader(k))?;
322        }
323
324        let mut stream_session = match driver.open_session(session_request_proto).await {
325            Ok(stream_session) => stream_session,
326            Err(driver_error) => {
327                return Err(ConnectingError::ConnectionError(
328                    ConnectionError::with_driver_error(driver_error, &quic_connection),
329                ))
330            }
331        };
332
333        let session_id = stream_session.session_id();
334
335        match stream_session
336            .write_frame(stream_session.request().headers().generate_frame())
337            .await
338        {
339            Ok(()) => {}
340            Err(ProtoWriteError::Stopped) => {
341                return Err(ConnectingError::SessionRejected);
342            }
343            Err(ProtoWriteError::NotConnected) => {
344                return Err(ConnectingError::with_no_connection(&quic_connection));
345            }
346        }
347
348        let frame = loop {
349            let frame = match stream_session.read_frame().await {
350                Ok(frame) => frame,
351                Err(ProtoReadError::H3(error_code)) => {
352                    quic_connection.close(varint_w2q(error_code.to_code()), b"");
353                    return Err(ConnectingError::ConnectionError(
354                        ConnectionError::local_h3_error(error_code),
355                    ));
356                }
357                Err(ProtoReadError::IO(_io_error)) => {
358                    return Err(ConnectingError::with_no_connection(&quic_connection));
359                }
360            };
361
362            if let FrameKind::Exercise(_) = frame.kind() {
363                continue;
364            }
365            break frame;
366        };
367
368        if !matches!(frame.kind(), FrameKind::Headers) {
369            quic_connection.close(varint_w2q(ErrorCode::FrameUnexpected.to_code()), b"");
370            return Err(ConnectingError::ConnectionError(
371                ConnectionError::local_h3_error(ErrorCode::FrameUnexpected),
372            ));
373        }
374
375        let headers = match Headers::with_frame(&frame) {
376            Ok(headers) => headers,
377            Err(error_code) => {
378                quic_connection.close(varint_w2q(error_code.to_code()), b"");
379                return Err(ConnectingError::ConnectionError(
380                    ConnectionError::local_h3_error(error_code),
381                ));
382            }
383        };
384
385        let session_response = match SessionResponseProto::try_from(headers) {
386            Ok(session_response) => session_response,
387            Err(_) => {
388                quic_connection.close(varint_w2q(ErrorCode::Message.to_code()), b"");
389                return Err(ConnectingError::ConnectionError(
390                    ConnectionError::local_h3_error(ErrorCode::Message),
391                ));
392            }
393        };
394
395        if session_response.code().is_successful() {
396            match driver.register_session(stream_session).await {
397                Ok(()) => {}
398                Err(driver_error) => {
399                    return Err(ConnectingError::ConnectionError(
400                        ConnectionError::with_driver_error(driver_error, &quic_connection),
401                    ))
402                }
403            }
404        } else {
405            return Err(ConnectingError::SessionRejected);
406        }
407
408        Ok(Connection::new(quic_connection, driver, session_id))
409    }
410}
411
/// Options for establishing a client WebTransport connection.
///
/// Used in [`Endpoint::connect`].
///
/// # Examples
///
/// ```no_run
/// # use anyhow::Result;
/// # use wtransport::endpoint::endpoint_side::Client;
/// # use wtransport::endpoint::ConnectOptions;
/// # async fn example(endpoint: wtransport::Endpoint<Client>) -> Result<()> {
/// let options = ConnectOptions::builder("https://example.com:4433/webtransport")
///     .add_header("Authorization", "AuthToken")
///     .build();
/// let connection = endpoint.connect(options).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ConnectOptions {
    /// Target URL. It is not parsed or validated until a connection is attempted.
    url: String,
    /// Extra headers to send with the session request, keyed by header name.
    additional_headers: HashMap<String, String>,
}

impl ConnectOptions {
    /// Creates a [`ConnectRequestBuilder`] for the given URL, with no additional headers.
    ///
    /// # Arguments
    ///
    /// * `url` - A [URL](https://en.wikipedia.org/wiki/URL) string representing the WebTransport
    ///           endpoint to connect to. It must have an `https` scheme.
    ///           The URL can specify either an IP address or a hostname.
    ///           When specifying a hostname, the method will internally perform DNS resolution,
    ///           configured with
    ///           [`ClientConfigBuilder::dns_resolver`](crate::config::ClientConfigBuilder::dns_resolver).
    pub fn builder<S>(url: S) -> ConnectRequestBuilder
    where
        S: ToString,
    {
        ConnectRequestBuilder {
            url: url.to_string(),
            additional_headers: Default::default(),
        }
    }

    /// Gets the URL which this will connect to.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Gets the additional headers that will be passed when connecting.
    pub fn additional_headers(&self) -> &HashMap<String, String> {
        &self.additional_headers
    }
}

/// A trait for converting types into [`ConnectOptions`].
///
/// Implemented for [`ConnectOptions`] itself and for [`ConnectRequestBuilder`]. It is also
/// implemented for any [`ToString`] type, whose string form is used as the URL with no
/// additional headers.
pub trait IntoConnectOptions {
    /// Perform value-to-value conversion into [`ConnectOptions`].
    fn into_options(self) -> ConnectOptions;
}

/// A builder for [`ConnectOptions`].
///
/// See [`ConnectOptions::builder`].
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` is called or it is passed to `Endpoint::connect`"]
pub struct ConnectRequestBuilder {
    url: String,
    additional_headers: HashMap<String, String>,
}

impl ConnectRequestBuilder {
    /// Adds a header to the connection options.
    ///
    /// Both `key` and `value` are converted with [`ToString`]. The key is stored as given,
    /// with no case normalization. Adding the same key again replaces the previous value.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use wtransport::endpoint::ConnectOptions;
    ///
    /// let options = ConnectOptions::builder("https://example.com:4433/webtransport")
    ///     .add_header("Authorization", "AuthToken")
    ///     .build();
    /// ```
    pub fn add_header<K, V>(mut self, key: K, value: V) -> Self
    where
        K: ToString,
        V: ToString,
    {
        self.additional_headers
            .insert(key.to_string(), value.to_string());
        self
    }

    /// Constructs the [`ConnectOptions`] from the builder configuration.
    #[must_use]
    pub fn build(self) -> ConnectOptions {
        ConnectOptions {
            url: self.url,
            additional_headers: self.additional_headers,
        }
    }
}

impl IntoConnectOptions for ConnectRequestBuilder {
    fn into_options(self) -> ConnectOptions {
        self.build()
    }
}

impl IntoConnectOptions for ConnectOptions {
    fn into_options(self) -> ConnectOptions {
        self
    }
}

impl<S> IntoConnectOptions for S
where
    S: ToString,
{
    fn into_options(self) -> ConnectOptions {
        ConnectOptions::builder(self).build()
    }
}
533
534type DynFutureIncomingSession =
535    dyn Future<Output = Result<SessionRequest, ConnectionError>> + Send + Sync;
536
537/// [`IntoFuture`] for an in-progress incoming connection attempt.
538///
539/// Created by [`Endpoint::accept`].
540pub struct IncomingSession(quinn::Incoming);
541
542impl IncomingSession {
543    /// The peer's UDP address.
544    pub fn remote_address(&self) -> SocketAddr {
545        self.0.remote_address()
546    }
547
548    /// Whether the socket address that is initiating this connection has been validated.
549    ///
550    /// This means that the sender of the initial packet has proved that they can receive traffic
551    /// sent to `self.remote_address()`.
552    pub fn remote_address_validated(&self) -> bool {
553        self.0.remote_address_validated()
554    }
555
556    /// Respond with a retry packet, requiring the client to retry with address validation
557    ///
558    /// # Panics
559    ///
560    /// If `remote_address_validated()` is true.
561    pub fn retry(self) {
562        self.0.retry().expect("remote address already verified");
563    }
564
565    /// Reject this incoming connection attempt.
566    pub fn refuse(self) {
567        self.0.refuse();
568    }
569
570    /// Ignore this incoming connection attempt, not sending any packet in response.
571    pub fn ignore(self) {
572        self.0.ignore();
573    }
574}
575
576impl IntoFuture for IncomingSession {
577    type IntoFuture = IncomingSessionFuture;
578    type Output = Result<SessionRequest, ConnectionError>;
579
580    fn into_future(self) -> Self::IntoFuture {
581        IncomingSessionFuture::new(self.0)
582    }
583}
584
585/// [`Future`] for an in-progress incoming connection attempt.
586///
587/// Created by awaiting an [`IncomingSession`]
588pub struct IncomingSessionFuture(Pin<Box<DynFutureIncomingSession>>);
589
590impl IncomingSessionFuture {
591    fn new(quic_incoming: quinn::Incoming) -> Self {
592        Self(Box::pin(Self::accept(quic_incoming)))
593    }
594
595    async fn accept(quic_incoming: quinn::Incoming) -> Result<SessionRequest, ConnectionError> {
596        let quic_connection = quic_incoming.await?;
597
598        let driver = Driver::init(quic_connection.clone());
599
600        let _settings = driver.accept_settings().await.map_err(|driver_error| {
601            ConnectionError::with_driver_error(driver_error, &quic_connection)
602        })?;
603
604        // TODO(biagio): validate settings
605
606        let stream_session = driver.accept_session().await.map_err(|driver_error| {
607            ConnectionError::with_driver_error(driver_error, &quic_connection)
608        })?;
609
610        Ok(SessionRequest::new(quic_connection, driver, stream_session))
611    }
612}
613
614impl Future for IncomingSessionFuture {
615    type Output = Result<SessionRequest, ConnectionError>;
616
617    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
618        Future::poll(self.0.as_mut(), cx)
619    }
620}
621
622/// A incoming client session request.
623///
624/// Server should use methods [`accept`](Self::accept), [`forbidden`](Self::forbidden),
625/// or [`not_found`](Self::not_found) in order to validate or reject the client request.
626pub struct SessionRequest {
627    quic_connection: quinn::Connection,
628    driver: Driver,
629    stream_session: StreamSession,
630}
631
632impl SessionRequest {
633    pub(crate) fn new(
634        quic_connection: quinn::Connection,
635        driver: Driver,
636        stream_session: StreamSession,
637    ) -> Self {
638        Self {
639            quic_connection,
640            driver,
641            stream_session,
642        }
643    }
644
645    /// Returns the peer's UDP address.
646    ///
647    /// **Note**: as QUIC supports migration, remote address may change
648    /// during connection. Furthermore, when IPv6 support is enabled, IPv4
649    /// addresses may be mapped to IPv6.
650    #[inline(always)]
651    pub fn remote_address(&self) -> SocketAddr {
652        self.quic_connection.remote_address()
653    }
654
655    /// Returns the `:authority` field of the request.
656    pub fn authority(&self) -> &str {
657        self.stream_session.request().authority()
658    }
659
660    /// Returns the `:path` field of the request.
661    pub fn path(&self) -> &str {
662        self.stream_session.request().path()
663    }
664
665    /// Returns the `origin` field of the request if present.
666    pub fn origin(&self) -> Option<&str> {
667        self.stream_session.request().origin()
668    }
669
670    /// Returns the `user-agent` field of the request if present.
671    pub fn user_agent(&self) -> Option<&str> {
672        self.stream_session.request().user_agent()
673    }
674
675    /// Returns all header fields associated with the request.
676    pub fn headers(&self) -> &HashMap<String, String> {
677        self.stream_session.request().headers().as_ref()
678    }
679
680    /// Accepts the client request and it establishes the WebTransport session.
681    pub async fn accept(mut self) -> Result<Connection, ConnectionError> {
682        let response = SessionResponseProto::ok();
683
684        self.send_response(response).await?;
685
686        let session_id = self.stream_session.session_id();
687
688        self.driver
689            .register_session(self.stream_session)
690            .await
691            .map_err(|driver_error| {
692                ConnectionError::with_driver_error(driver_error, &self.quic_connection)
693            })?;
694
695        Ok(Connection::new(
696            self.quic_connection,
697            self.driver,
698            session_id,
699        ))
700    }
701
702    /// Rejects the client request by replying with `403` status code.
703    pub async fn forbidden(self) {
704        self.reject(SessionResponseProto::forbidden()).await;
705    }
706
707    /// Rejects the client request by replying with `404` status code.
708    pub async fn not_found(self) {
709        self.reject(SessionResponseProto::not_found()).await;
710    }
711
712    /// Rejects the client request by replying with `429` status code.
713    pub async fn too_many_requests(self) {
714        self.reject(SessionResponseProto::too_many_requests()).await;
715    }
716
717    async fn reject(mut self, response: SessionResponseProto) {
718        let _ = self.send_response(response).await;
719        self.stream_session.finish().await;
720    }
721
722    async fn send_response(
723        &mut self,
724        response: SessionResponseProto,
725    ) -> Result<(), ConnectionError> {
726        let frame = response.headers().generate_frame();
727
728        match self.stream_session.write_frame(frame).await {
729            Ok(()) => Ok(()),
730            Err(ProtoWriteError::NotConnected) => {
731                Err(ConnectionError::no_connect(&self.quic_connection))
732            }
733            Err(ProtoWriteError::Stopped) => {
734                self.quic_connection
735                    .close(varint_w2q(ErrorCode::ClosedCriticalStream.to_code()), b"");
736
737                Err(ConnectionError::local_h3_error(
738                    ErrorCode::ClosedCriticalStream,
739                ))
740            }
741        }
742    }
743}