1use crate::config::ClientConfig;
2use crate::config::DnsResolver;
3use crate::config::ServerConfig;
4use crate::connection::Connection;
5use crate::driver::streams::session::StreamSession;
6use crate::driver::streams::ProtoReadError;
7use crate::driver::streams::ProtoWriteError;
8use crate::driver::utils::varint_w2q;
9use crate::driver::Driver;
10use crate::error::ConnectingError;
11use crate::error::ConnectionError;
12use crate::VarInt;
13use quinn::TokioRuntime;
14use std::collections::HashMap;
15use std::future::Future;
16use std::future::IntoFuture;
17use std::marker::PhantomData;
18use std::net::SocketAddr;
19use std::net::SocketAddrV4;
20use std::net::SocketAddrV6;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::Context;
24use std::task::Poll;
25use tracing::debug;
26use url::Host;
27use url::Url;
28use wtransport_proto::error::ErrorCode;
29use wtransport_proto::frame::FrameKind;
30use wtransport_proto::headers::Headers;
31use wtransport_proto::session::ReservedHeader;
32use wtransport_proto::session::SessionRequest as SessionRequestProto;
33use wtransport_proto::session::SessionResponse as SessionResponseProto;
34
/// Marker types that select the role an [`Endpoint`] plays.
pub mod endpoint_side {
    use super::*;

    /// Role marker for an [`Endpoint`] created with [`Endpoint::server`].
    pub struct Server {
        // Zero-sized placeholder; this side stores no data of its own.
        pub(super) _marker: PhantomData<()>,
    }

    /// Role marker for an [`Endpoint`] created with [`Endpoint::client`].
    pub struct Client {
        // Resolver that `Endpoint::connect` uses when the URL host is a domain name.
        pub(super) dns_resolver: Arc<dyn DnsResolver + Send + Sync>,
    }
}
53
/// Wrapper around a `quinn::Endpoint` tagged with the role (`Side`) it plays.
///
/// Constructors in this file exist for [`endpoint_side::Server`] and
/// [`endpoint_side::Client`].
pub struct Endpoint<Side> {
    // QUIC endpoint that every method delegates to.
    endpoint: quinn::Endpoint,
    // Role-specific data (see `endpoint_side`).
    side: Side,
}
103
104impl<Side> Endpoint<Side> {
105 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
107 self.endpoint.close(varint_w2q(error_code), reason);
108 }
109
110 pub async fn wait_idle(&self) {
112 self.endpoint.wait_idle().await;
113 }
114
115 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
117 self.endpoint.local_addr()
118 }
119
120 pub fn open_connections(&self) -> usize {
122 self.endpoint.open_connections()
123 }
124}
125
126impl Endpoint<endpoint_side::Server> {
127 pub fn server(server_config: ServerConfig) -> std::io::Result<Self> {
129 let endpoint_config = server_config.endpoint_config;
130 let quic_config = server_config.quic_config;
131 let socket = server_config.bind_address_config.bind_socket()?;
132 let runtime = Arc::new(TokioRuntime);
133
134 let endpoint = quinn::Endpoint::new(endpoint_config, Some(quic_config), socket, runtime)?;
135
136 Ok(Self {
137 endpoint,
138 side: endpoint_side::Server {
139 _marker: PhantomData,
140 },
141 })
142 }
143
144 pub async fn accept(&self) -> IncomingSession {
146 let quic_incoming = self
147 .endpoint
148 .accept()
149 .await
150 .expect("Endpoint cannot be closed");
151
152 debug!("New incoming QUIC connection");
153
154 IncomingSession(quic_incoming)
155 }
156
157 pub fn reload_config(&self, server_config: ServerConfig, rebind: bool) -> std::io::Result<()> {
168 if rebind {
169 let socket = server_config.bind_address_config.bind_socket()?;
170 self.endpoint.rebind(socket)?;
171 }
172
173 let quic_config = server_config.quic_config;
174 self.endpoint.set_server_config(Some(quic_config));
175
176 Ok(())
177 }
178}
179
180impl Endpoint<endpoint_side::Client> {
181 pub fn client(client_config: ClientConfig) -> std::io::Result<Self> {
183 let endpoint_config = client_config.endpoint_config;
184 let quic_config = client_config.quic_config;
185 let socket = client_config.bind_address_config.bind_socket()?;
186 let runtime = Arc::new(TokioRuntime);
187
188 let mut endpoint = quinn::Endpoint::new(endpoint_config, None, socket, runtime)?;
189
190 endpoint.set_default_client_config(quic_config);
191
192 Ok(Self {
193 endpoint,
194 side: endpoint_side::Client {
195 dns_resolver: client_config.dns_resolver,
196 },
197 })
198 }
199
200 pub async fn connect<O>(&self, options: O) -> Result<Connection, ConnectingError>
256 where
257 O: IntoConnectOptions,
258 {
259 let options = options.into_options();
260
261 let url = Url::parse(&options.url)
262 .map_err(|parse_error| ConnectingError::InvalidUrl(parse_error.to_string()))?;
263
264 if url.scheme() != "https" {
265 return Err(ConnectingError::InvalidUrl(
266 "WebTransport URL scheme must be 'https'".to_string(),
267 ));
268 }
269
270 let host = url.host().expect("https scheme must have an host");
271 let port = url.port().unwrap_or(443);
272
273 let (socket_address, server_name) = match host {
274 Host::Domain(domain) => {
275 let socket_address = self
276 .side
277 .dns_resolver
278 .resolve(&format!("{domain}:{port}"))
279 .await
280 .map_err(ConnectingError::DnsLookup)?
281 .ok_or(ConnectingError::DnsNotFound)?;
282
283 (socket_address, domain.to_string())
284 }
285 Host::Ipv4(address) => {
286 let socket_address = SocketAddr::V4(SocketAddrV4::new(address, port));
287 (socket_address, address.to_string())
288 }
289 Host::Ipv6(address) => {
290 let socket_address = SocketAddr::V6(SocketAddrV6::new(address, port, 0, 0));
291 (socket_address, address.to_string())
292 }
293 };
294
295 let quic_connection = self
296 .endpoint
297 .connect(socket_address, &server_name)
298 .map_err(ConnectingError::with_connect_error)?
299 .await
300 .map_err(|connection_error| {
301 ConnectingError::ConnectionError(connection_error.into())
302 })?;
303
304 let driver = Driver::init(quic_connection.clone());
305
306 let _settings = driver.accept_settings().await.map_err(|driver_error| {
307 ConnectingError::ConnectionError(ConnectionError::with_driver_error(
308 driver_error,
309 &quic_connection,
310 ))
311 })?;
312
313 let mut session_request_proto =
316 SessionRequestProto::new(url.as_ref()).expect("Url has been already validate");
317
318 for (k, v) in options.additional_headers {
319 session_request_proto
320 .insert(k.clone(), v)
321 .map_err(|ReservedHeader| ConnectingError::ReservedHeader(k))?;
322 }
323
324 let mut stream_session = match driver.open_session(session_request_proto).await {
325 Ok(stream_session) => stream_session,
326 Err(driver_error) => {
327 return Err(ConnectingError::ConnectionError(
328 ConnectionError::with_driver_error(driver_error, &quic_connection),
329 ))
330 }
331 };
332
333 let session_id = stream_session.session_id();
334
335 match stream_session
336 .write_frame(stream_session.request().headers().generate_frame())
337 .await
338 {
339 Ok(()) => {}
340 Err(ProtoWriteError::Stopped) => {
341 return Err(ConnectingError::SessionRejected);
342 }
343 Err(ProtoWriteError::NotConnected) => {
344 return Err(ConnectingError::with_no_connection(&quic_connection));
345 }
346 }
347
348 let frame = loop {
349 let frame = match stream_session.read_frame().await {
350 Ok(frame) => frame,
351 Err(ProtoReadError::H3(error_code)) => {
352 quic_connection.close(varint_w2q(error_code.to_code()), b"");
353 return Err(ConnectingError::ConnectionError(
354 ConnectionError::local_h3_error(error_code),
355 ));
356 }
357 Err(ProtoReadError::IO(_io_error)) => {
358 return Err(ConnectingError::with_no_connection(&quic_connection));
359 }
360 };
361
362 if let FrameKind::Exercise(_) = frame.kind() {
363 continue;
364 }
365 break frame;
366 };
367
368 if !matches!(frame.kind(), FrameKind::Headers) {
369 quic_connection.close(varint_w2q(ErrorCode::FrameUnexpected.to_code()), b"");
370 return Err(ConnectingError::ConnectionError(
371 ConnectionError::local_h3_error(ErrorCode::FrameUnexpected),
372 ));
373 }
374
375 let headers = match Headers::with_frame(&frame) {
376 Ok(headers) => headers,
377 Err(error_code) => {
378 quic_connection.close(varint_w2q(error_code.to_code()), b"");
379 return Err(ConnectingError::ConnectionError(
380 ConnectionError::local_h3_error(error_code),
381 ));
382 }
383 };
384
385 let Ok(session_response) = SessionResponseProto::try_from(headers) else {
386 quic_connection.close(varint_w2q(ErrorCode::Message.to_code()), b"");
387 return Err(ConnectingError::ConnectionError(
388 ConnectionError::local_h3_error(ErrorCode::Message),
389 ));
390 };
391
392 if session_response.code().is_successful() {
393 match driver.register_session(stream_session).await {
394 Ok(()) => {}
395 Err(driver_error) => {
396 return Err(ConnectingError::ConnectionError(
397 ConnectionError::with_driver_error(driver_error, &quic_connection),
398 ))
399 }
400 }
401 } else {
402 return Err(ConnectingError::SessionRejected);
403 }
404
405 Ok(Connection::new(quic_connection, driver, session_id))
406 }
407}
408
/// Options for a client connection: the target URL plus extra request headers.
///
/// Consumed by `Endpoint::connect` through [`IntoConnectOptions`].
#[derive(Debug, Clone)]
pub struct ConnectOptions {
    // Target URL, stored unparsed; it is parsed and validated when connecting.
    url: String,
    // Extra headers inserted into the session request, keyed by header name.
    additional_headers: HashMap<String, String>,
}
432
433impl ConnectOptions {
434 pub fn builder<S>(url: S) -> ConnectRequestBuilder
445 where
446 S: ToString,
447 {
448 ConnectRequestBuilder {
449 url: url.to_string(),
450 additional_headers: Default::default(),
451 }
452 }
453
454 pub fn url(&self) -> &str {
456 &self.url
457 }
458
459 pub fn additional_headers(&self) -> &HashMap<String, String> {
461 &self.additional_headers
462 }
463}
464
/// Conversion into [`ConnectOptions`], used as the argument bound of `Endpoint::connect`.
pub trait IntoConnectOptions {
    /// Consumes `self` and produces the [`ConnectOptions`] to connect with.
    fn into_options(self) -> ConnectOptions;
}
470
/// Builder for [`ConnectOptions`], obtained from [`ConnectOptions::builder`].
pub struct ConnectRequestBuilder {
    // Target URL, stored unparsed.
    url: String,
    // Headers accumulated through `add_header`, keyed by header name.
    additional_headers: HashMap<String, String>,
}
478
479impl ConnectRequestBuilder {
480 pub fn add_header<K, V>(mut self, key: K, value: V) -> Self
492 where
493 K: ToString,
494 V: ToString,
495 {
496 self.additional_headers
497 .insert(key.to_string(), value.to_string());
498 self
499 }
500
501 pub fn build(self) -> ConnectOptions {
503 ConnectOptions {
504 url: self.url,
505 additional_headers: self.additional_headers,
506 }
507 }
508}
509
510impl IntoConnectOptions for ConnectRequestBuilder {
511 fn into_options(self) -> ConnectOptions {
512 self.build()
513 }
514}
515
/// Already-built options are passed through unchanged.
impl IntoConnectOptions for ConnectOptions {
    fn into_options(self) -> ConnectOptions {
        self
    }
}
521
522impl<S> IntoConnectOptions for S
523where
524 S: ToString,
525{
526 fn into_options(self) -> ConnectOptions {
527 ConnectOptions::builder(self).build()
528 }
529}
530
/// Type-erased future driving an incoming connection up to its [`SessionRequest`];
/// boxed and pinned inside [`IncomingSessionFuture`].
type DynFutureIncomingSession =
    dyn Future<Output = Result<SessionRequest, ConnectionError>> + Send + Sync;
533
/// An incoming QUIC connection attempt, as returned by `Endpoint::accept`.
///
/// Await it (via `IntoFuture`) to obtain a [`SessionRequest`], or dispose of it with
/// [`IncomingSession::retry`], [`IncomingSession::refuse`] or [`IncomingSession::ignore`].
pub struct IncomingSession(quinn::Incoming);
538
539impl IncomingSession {
540 pub fn remote_address(&self) -> SocketAddr {
542 self.0.remote_address()
543 }
544
545 pub fn remote_address_validated(&self) -> bool {
550 self.0.remote_address_validated()
551 }
552
553 pub fn retry(self) {
559 self.0.retry().expect("remote address already verified");
560 }
561
562 pub fn refuse(self) {
564 self.0.refuse();
565 }
566
567 pub fn ignore(self) {
569 self.0.ignore();
570 }
571}
572
573impl IntoFuture for IncomingSession {
574 type IntoFuture = IncomingSessionFuture;
575 type Output = Result<SessionRequest, ConnectionError>;
576
577 fn into_future(self) -> Self::IntoFuture {
578 IncomingSessionFuture::new(self.0)
579 }
580}
581
/// Future resolving to a [`SessionRequest`] (or a [`ConnectionError`]) for an incoming
/// connection; holds the boxed, pinned, type-erased future that does the work.
pub struct IncomingSessionFuture(Pin<Box<DynFutureIncomingSession>>);
586
587impl IncomingSessionFuture {
588 #[cfg(feature = "quinn")]
593 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
594 pub fn with_quic_incoming(quic_incoming: quinn::Incoming) -> Self {
595 Self::new(quic_incoming)
596 }
597
598 #[cfg(feature = "quinn")]
606 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
607 pub fn with_quic_connecting(quic_connecting: quinn::Connecting) -> Self {
608 Self(Box::pin(async move {
609 let quic_connection = quic_connecting.await?;
610 Self::accept(quic_connection).await
611 }))
612 }
613
614 fn new(quic_incoming: quinn::Incoming) -> Self {
615 Self(Box::pin(async move {
616 let quic_connection = quic_incoming.await?;
617 Self::accept(quic_connection).await
618 }))
619 }
620
621 async fn accept(quic_connection: quinn::Connection) -> Result<SessionRequest, ConnectionError> {
622 let driver = Driver::init(quic_connection.clone());
623
624 let _settings = driver.accept_settings().await.map_err(|driver_error| {
625 ConnectionError::with_driver_error(driver_error, &quic_connection)
626 })?;
627
628 let stream_session = driver.accept_session().await.map_err(|driver_error| {
631 ConnectionError::with_driver_error(driver_error, &quic_connection)
632 })?;
633
634 Ok(SessionRequest::new(quic_connection, driver, stream_session))
635 }
636}
637
638impl Future for IncomingSessionFuture {
639 type Output = Result<SessionRequest, ConnectionError>;
640
641 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
642 Future::poll(self.0.as_mut(), cx)
643 }
644}
645
/// A session request received on an incoming connection, waiting to be accepted or
/// rejected by the server.
pub struct SessionRequest {
    // QUIC connection the request arrived on.
    quic_connection: quinn::Connection,
    // Driver started on that connection.
    driver: Driver,
    // Stream session carrying the request; the response is written to it.
    stream_session: StreamSession,
}
655
656impl SessionRequest {
657 pub(crate) fn new(
658 quic_connection: quinn::Connection,
659 driver: Driver,
660 stream_session: StreamSession,
661 ) -> Self {
662 Self {
663 quic_connection,
664 driver,
665 stream_session,
666 }
667 }
668
669 #[inline(always)]
675 pub fn remote_address(&self) -> SocketAddr {
676 self.quic_connection.remote_address()
677 }
678
679 pub fn authority(&self) -> &str {
681 self.stream_session.request().authority()
682 }
683
684 pub fn path(&self) -> &str {
686 self.stream_session.request().path()
687 }
688
689 pub fn origin(&self) -> Option<&str> {
691 self.stream_session.request().origin()
692 }
693
694 pub fn user_agent(&self) -> Option<&str> {
696 self.stream_session.request().user_agent()
697 }
698
699 pub fn headers(&self) -> &HashMap<String, String> {
701 self.stream_session.request().headers().as_ref()
702 }
703
704 pub async fn accept(mut self) -> Result<Connection, ConnectionError> {
706 let response = SessionResponseProto::ok();
707
708 self.send_response(response).await?;
709
710 let session_id = self.stream_session.session_id();
711
712 self.driver
713 .register_session(self.stream_session)
714 .await
715 .map_err(|driver_error| {
716 ConnectionError::with_driver_error(driver_error, &self.quic_connection)
717 })?;
718
719 Ok(Connection::new(
720 self.quic_connection,
721 self.driver,
722 session_id,
723 ))
724 }
725
726 pub async fn forbidden(self) {
728 self.reject(SessionResponseProto::forbidden()).await;
729 }
730
731 pub async fn not_found(self) {
733 self.reject(SessionResponseProto::not_found()).await;
734 }
735
736 pub async fn too_many_requests(self) {
738 self.reject(SessionResponseProto::too_many_requests()).await;
739 }
740
741 async fn reject(mut self, response: SessionResponseProto) {
742 let _ = self.send_response(response).await;
743 self.stream_session.finish().await;
744 }
745
746 async fn send_response(
747 &mut self,
748 response: SessionResponseProto,
749 ) -> Result<(), ConnectionError> {
750 let frame = response.headers().generate_frame();
751
752 match self.stream_session.write_frame(frame).await {
753 Ok(()) => Ok(()),
754 Err(ProtoWriteError::NotConnected) => {
755 Err(ConnectionError::no_connect(&self.quic_connection))
756 }
757 Err(ProtoWriteError::Stopped) => {
758 self.quic_connection
759 .close(varint_w2q(ErrorCode::ClosedCriticalStream.to_code()), b"");
760
761 Err(ConnectionError::local_h3_error(
762 ErrorCode::ClosedCriticalStream,
763 ))
764 }
765 }
766 }
767}