wtransport/driver/streams/
mod.rs1use crate::driver::utils::streamid_q2w;
2use crate::driver::utils::varint_q2w;
3use crate::driver::utils::varint_w2q;
4use crate::error::ClosedStream;
5use crate::error::StreamReadError;
6use crate::error::StreamReadExactError;
7use crate::error::StreamWriteError;
8use crate::SessionId;
9use crate::StreamId;
10use crate::VarInt;
11use std::pin::Pin;
12use std::task::ready;
13use std::task::Context;
14use std::task::Poll;
15use tokio::io::ReadBuf;
16use wtransport_proto::frame::Frame;
17use wtransport_proto::session::SessionRequest;
18use wtransport_proto::stream as stream_proto;
19use wtransport_proto::stream::Stream as StreamProto;
20use wtransport_proto::stream_header::StreamHeader;
21use wtransport_proto::stream_header::StreamKind;
22
/// Error returned by the `wtransport_proto` stream layer when reading from a stream.
pub type ProtoReadError = wtransport_proto::stream::IoReadError;
/// Error returned by the `wtransport_proto` stream layer when writing to a stream.
pub type ProtoWriteError = wtransport_proto::stream::IoWriteError;
25
/// Error returned when asking the peer to stop a receive stream fails
/// because the stream can no longer be stopped (quinn rejected the request).
///
/// This is a plain marker value, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlreadyStop;
28
29#[derive(Debug)]
30pub struct QuicSendStream(quinn::SendStream);
31
32impl QuicSendStream {
33 #[inline(always)]
34 pub async fn write(&mut self, buf: &[u8]) -> Result<usize, StreamWriteError> {
35 let written = self.0.write(buf).await?;
36 Ok(written)
37 }
38
39 #[inline(always)]
40 pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), StreamWriteError> {
41 self.0.write_all(buf).await?;
42 Ok(())
43 }
44
45 #[inline(always)]
46 pub async fn finish(&mut self) -> Result<(), StreamWriteError> {
47 let _ = self.0.finish();
48 let result = self.stopped().await;
49 if matches!(result, StreamWriteError::Closed) {
50 Ok(())
51 } else {
52 Err(result)
53 }
54 }
55
56 #[inline(always)]
57 pub fn set_priority(&self, priority: i32) {
58 let _ = self.0.set_priority(priority);
59 }
60
61 #[inline(always)]
62 pub fn priority(&self) -> i32 {
66 self.0.priority().expect("Stream has been reset")
67 }
68
69 pub async fn stopped(&mut self) -> StreamWriteError {
70 match self.0.stopped().await {
71 Ok(None) => StreamWriteError::Closed,
72 Ok(Some(code)) => StreamWriteError::Stopped(varint_q2w(code)),
73 Err(quinn::StoppedError::ConnectionLost(_)) => StreamWriteError::NotConnected,
74 Err(quinn::StoppedError::ZeroRttRejected) => StreamWriteError::QuicProto,
75 }
76 }
77
78 #[inline(always)]
79 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
80 self.0
81 .reset(varint_w2q(error_code))
82 .map_err(|_| ClosedStream)
83 }
84
85 #[inline(always)]
86 pub fn id(&self) -> StreamId {
87 streamid_q2w(self.0.id())
88 }
89
90 #[cfg(feature = "quinn")]
91 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
92 #[inline(always)]
93 pub fn quic_stream(&self) -> &quinn::SendStream {
94 &self.0
95 }
96
97 #[cfg(feature = "quinn")]
98 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
99 #[inline(always)]
100 pub fn quic_stream_mut(&mut self) -> &mut quinn::SendStream {
101 &mut self.0
102 }
103}
104
105impl wtransport_proto::bytes::AsyncWrite for QuicSendStream {
106 #[inline(always)]
107 fn poll_write(
108 mut self: Pin<&mut Self>,
109 cx: &mut Context<'_>,
110 buf: &[u8],
111 ) -> Poll<std::io::Result<usize>> {
112 tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
113 }
114}
115
116impl tokio::io::AsyncWrite for QuicSendStream {
117 #[inline(always)]
118 fn poll_write(
119 mut self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &[u8],
122 ) -> Poll<Result<usize, std::io::Error>> {
123 tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
124 }
125
126 #[inline(always)]
127 fn poll_flush(
128 mut self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 ) -> Poll<Result<(), std::io::Error>> {
131 tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
132 }
133
134 #[inline(always)]
135 fn poll_shutdown(
136 mut self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 ) -> Poll<Result<(), std::io::Error>> {
139 tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
140 }
141
142 #[inline(always)]
143 fn poll_write_vectored(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 bufs: &[std::io::IoSlice<'_>],
147 ) -> Poll<Result<usize, std::io::Error>> {
148 tokio::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs)
149 }
150
151 fn is_write_vectored(&self) -> bool {
152 tokio::io::AsyncWrite::is_write_vectored(&self.0)
153 }
154}
155
156#[derive(Debug)]
157pub struct QuicRecvStream(quinn::RecvStream);
158
159impl QuicRecvStream {
160 #[inline(always)]
161 pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, StreamReadError> {
162 match self.0.read(buf).await? {
163 Some(read) => Ok(Some(read)),
164 None => Ok(None),
165 }
166 }
167
168 #[inline(always)]
169 pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), StreamReadExactError> {
170 self.0
171 .read_exact(buf)
172 .await
173 .map_err(|quic_error| match quic_error {
174 quinn::ReadExactError::FinishedEarly(read) => {
175 StreamReadExactError::FinishedEarly(read)
176 }
177 quinn::ReadExactError::ReadError(read) => StreamReadExactError::Read(read.into()),
178 })
179 }
180
181 #[inline(always)]
182 pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
183 self.0.stop(varint_w2q(error_code)).map_err(|_| AlreadyStop)
184 }
185
186 #[inline(always)]
187 pub fn id(&self) -> StreamId {
188 streamid_q2w(self.0.id())
189 }
190
191 #[cfg(feature = "quinn")]
192 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
193 #[inline(always)]
194 pub fn quic_stream(&self) -> &quinn::RecvStream {
195 &self.0
196 }
197
198 #[cfg(feature = "quinn")]
199 #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
200 #[inline(always)]
201 pub fn quic_stream_mut(&mut self) -> &mut quinn::RecvStream {
202 &mut self.0
203 }
204}
205
206impl wtransport_proto::bytes::AsyncRead for QuicRecvStream {
207 #[inline(always)]
208 fn poll_read(
209 mut self: Pin<&mut Self>,
210 cx: &mut Context<'_>,
211 buf: &mut [u8],
212 ) -> Poll<std::io::Result<usize>> {
213 let mut buffer = ReadBuf::new(buf);
214
215 match ready!(tokio::io::AsyncRead::poll_read(
216 Pin::new(&mut self.0),
217 cx,
218 &mut buffer
219 )) {
220 Ok(()) => Poll::Ready(Ok(buffer.filled().len())),
221 Err(io_error) => Poll::Ready(Err(io_error)),
222 }
223 }
224}
225
226impl tokio::io::AsyncRead for QuicRecvStream {
227 #[inline(always)]
228 fn poll_read(
229 mut self: Pin<&mut Self>,
230 cx: &mut Context<'_>,
231 buf: &mut ReadBuf<'_>,
232 ) -> Poll<std::io::Result<()>> {
233 tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
234 }
235}
236
/// A QUIC stream paired with its protocol state.
///
/// `S` is the transport side: a [`QuicSendStream`], a [`QuicRecvStream`], or a
/// `(QuicSendStream, QuicRecvStream)` pair. `P` is a `wtransport_proto::stream`
/// state type that records which protocol layer the stream has reached.
/// The `upgrade` methods in the submodules consume one state and produce the next.
#[derive(Debug)]
pub struct Stream<S, P> {
    // Underlying QUIC stream half (or halves).
    stream: S,
    // Protocol-level state for this stream.
    proto: P,
}
242
243pub mod biremote {
244 use super::*;
245
246 pub type StreamBiRemoteQuic =
247 Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteQuic>;
248
249 pub type StreamBiRemoteH3 =
250 Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteH3>;
251
252 pub type StreamBiRemoteWT =
253 Stream<(QuicSendStream, QuicRecvStream), stream_proto::biremote::StreamBiRemoteWT>;
254
255 impl StreamBiRemoteQuic {
256 pub async fn accept_bi(quic_connection: &quinn::Connection) -> Option<Self> {
257 let stream = quic_connection.accept_bi().await.ok()?;
258 Some(Self {
259 stream: (QuicSendStream(stream.0), QuicRecvStream(stream.1)),
260 proto: StreamProto::accept_bi(),
261 })
262 }
263
264 pub fn upgrade(self) -> StreamBiRemoteH3 {
265 StreamBiRemoteH3 {
266 stream: self.stream,
267 proto: self.proto.upgrade(),
268 }
269 }
270
271 #[inline(always)]
272 pub fn id(&self) -> StreamId {
273 self.stream.0.id()
274 }
275 }
276
277 impl StreamBiRemoteH3 {
278 pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
279 self.proto.read_frame_async(&mut self.stream.1).await
280 }
281
282 pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
283 self.stream.1.stop(error_code)
284 }
285
286 pub fn upgrade(self, session_id: SessionId) -> StreamBiRemoteWT {
287 StreamBiRemoteWT {
288 stream: self.stream,
289 proto: self.proto.upgrade(session_id),
290 }
291 }
292
293 pub fn id(&self) -> StreamId {
294 self.stream.0.id()
295 }
296
297 pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
298 session::StreamSession {
299 stream: self.stream,
300 proto: self.proto.into_session(session_request),
301 }
302 }
303 }
304
305 impl StreamBiRemoteWT {
306 #[inline(always)]
307 pub fn session_id(&self) -> SessionId {
308 self.proto.session_id()
309 }
310
311 #[inline(always)]
312 pub fn id(&self) -> StreamId {
313 self.stream.0.id()
314 }
315
316 #[inline(always)]
317 pub fn into_stream(self) -> (QuicSendStream, QuicRecvStream) {
318 self.stream
319 }
320 }
321}
322
323pub mod bilocal {
324 use super::*;
325
326 pub type StreamBiLocalQuic =
327 Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalQuic>;
328
329 pub type StreamBiLocalH3 =
330 Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalH3>;
331
332 pub type StreamBiLocalWT =
333 Stream<(QuicSendStream, QuicRecvStream), stream_proto::bilocal::StreamBiLocalWT>;
334
335 impl StreamBiLocalQuic {
336 pub async fn open_bi(quic_connection: &quinn::Connection) -> Option<Self> {
337 let stream = quic_connection.open_bi().await.ok()?;
338 Some(Self {
339 stream: (QuicSendStream(stream.0), QuicRecvStream(stream.1)),
340 proto: StreamProto::open_bi(),
341 })
342 }
343
344 pub fn upgrade(self) -> StreamBiLocalH3 {
345 StreamBiLocalH3 {
346 stream: self.stream,
347 proto: self.proto.upgrade(),
348 }
349 }
350 }
351
352 impl StreamBiLocalH3 {
353 pub async fn upgrade(
354 mut self,
355 session_id: SessionId,
356 ) -> Result<StreamBiLocalWT, ProtoWriteError> {
357 let proto = self
358 .proto
359 .upgrade_async(session_id, &mut self.stream.0)
360 .await?;
361
362 Ok(StreamBiLocalWT {
363 stream: self.stream,
364 proto,
365 })
366 }
367
368 pub fn into_session(self, session_request: SessionRequest) -> session::StreamSession {
369 session::StreamSession {
370 stream: self.stream,
371 proto: self.proto.into_session(session_request),
372 }
373 }
374 }
375
376 impl StreamBiLocalWT {
377 pub fn into_stream(self) -> (QuicSendStream, QuicRecvStream) {
378 self.stream
379 }
380 }
381}
382
383pub mod uniremote {
384 use super::*;
385
386 pub type StreamUniRemoteQuic =
387 Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteQuic>;
388
389 pub type StreamUniRemoteH3 = Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteH3>;
390
391 pub type StreamUniRemoteWT = Stream<QuicRecvStream, stream_proto::uniremote::StreamUniRemoteWT>;
392
393 impl StreamUniRemoteQuic {
394 pub async fn accept_uni(quic_connection: &quinn::Connection) -> Option<Self> {
395 let stream = quic_connection.accept_uni().await.ok()?;
396 Some(Self {
397 stream: QuicRecvStream(stream),
398 proto: StreamProto::accept_uni(),
399 })
400 }
401
402 pub async fn upgrade(mut self) -> Result<StreamUniRemoteH3, ProtoReadError> {
403 let proto = self.proto.upgrade_async(&mut self.stream).await?;
404 Ok(StreamUniRemoteH3 {
405 stream: self.stream,
406 proto,
407 })
408 }
409
410 #[inline(always)]
411 pub fn id(&self) -> StreamId {
412 self.stream.id()
413 }
414 }
415
416 impl StreamUniRemoteH3 {
417 pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
418 self.proto.read_frame_async(&mut self.stream).await
419 }
420
421 pub fn kind(&self) -> StreamKind {
422 self.proto.kind()
423 }
424
425 pub fn upgrade(self) -> StreamUniRemoteWT {
426 StreamUniRemoteWT {
427 stream: self.stream,
428 proto: self.proto.upgrade(),
429 }
430 }
431
432 pub fn stream_mut(&mut self) -> &mut QuicRecvStream {
433 &mut self.stream
434 }
435 }
436
437 impl StreamUniRemoteWT {
438 #[inline(always)]
439 pub fn session_id(&self) -> SessionId {
440 self.proto.session_id()
441 }
442
443 #[inline(always)]
444 pub fn id(&self) -> StreamId {
445 self.stream.id()
446 }
447
448 #[inline(always)]
449 pub fn into_stream(self) -> QuicRecvStream {
450 self.stream
451 }
452 }
453}
454
455pub mod unilocal {
456 use super::*;
457
458 pub type StreamUniLocalQuic =
459 Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalQuic>;
460
461 pub type StreamUniLocalH3 = Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalH3>;
462
463 pub type StreamUniLocalWT = Stream<QuicSendStream, stream_proto::unilocal::StreamUniLocalWT>;
464
465 impl StreamUniLocalQuic {
466 pub async fn open_uni(quic_connection: &quinn::Connection) -> Option<Self> {
467 let stream = quic_connection.open_uni().await.ok()?;
468 Some(Self {
469 stream: QuicSendStream(stream),
470 proto: StreamProto::open_uni(),
471 })
472 }
473
474 pub async fn upgrade(
475 mut self,
476 stream_header: StreamHeader,
477 ) -> Result<StreamUniLocalH3, ProtoWriteError> {
478 let proto = self
479 .proto
480 .upgrade_async(stream_header, &mut self.stream)
481 .await?;
482
483 Ok(StreamUniLocalH3 {
484 stream: self.stream,
485 proto,
486 })
487 }
488 }
489
490 impl StreamUniLocalH3 {
491 pub async fn write_frame(&mut self, frame: Frame<'_>) -> Result<(), ProtoWriteError> {
492 self.proto.write_frame_async(frame, &mut self.stream).await
493 }
494
495 pub fn kind(&self) -> StreamKind {
496 self.proto.kind()
497 }
498
499 pub async fn stopped(&mut self) -> StreamWriteError {
500 self.stream.stopped().await
501 }
502
503 pub fn upgrade(self) -> StreamUniLocalWT {
504 StreamUniLocalWT {
505 stream: self.stream,
506 proto: self.proto.upgrade(),
507 }
508 }
509 }
510
511 impl StreamUniLocalWT {
512 pub fn into_stream(self) -> QuicSendStream {
513 self.stream
514 }
515 }
516}
517
518pub mod session {
519 use super::*;
520
521 pub type StreamSession =
522 Stream<(QuicSendStream, QuicRecvStream), stream_proto::session::StreamSession>;
523
524 impl StreamSession {
525 pub async fn read_frame<'a>(&mut self) -> Result<Frame<'a>, ProtoReadError> {
526 self.proto.read_frame_async(&mut self.stream.1).await
527 }
528
529 pub async fn write_frame(&mut self, frame: Frame<'_>) -> Result<(), ProtoWriteError> {
530 self.proto
531 .write_frame_async(frame, &mut self.stream.0)
532 .await
533 }
534
535 pub fn stop(&mut self, error_code: VarInt) -> Result<(), AlreadyStop> {
536 self.stream.1.stop(error_code)
537 }
538
539 pub fn id(&self) -> StreamId {
540 self.stream.0.id()
541 }
542
543 pub fn session_id(&self) -> SessionId {
544 SessionId::try_from_session_stream(self.id()).expect("Session stream must be valid")
545 }
546
547 pub fn request(&self) -> &SessionRequest {
548 self.proto.request()
549 }
550
551 pub async fn finish(&mut self) {
552 let _ = self.stream.0.finish().await;
553 }
554
555 pub fn reset(&mut self, error_code: VarInt) {
556 let _ = self.stream.0.reset(error_code);
557 }
558 }
559}
560
561impl From<quinn::WriteError> for StreamWriteError {
562 fn from(error: quinn::WriteError) -> Self {
563 match error {
564 quinn::WriteError::Stopped(code) => StreamWriteError::Stopped(varint_q2w(code)),
565 quinn::WriteError::ConnectionLost(_) | quinn::WriteError::ClosedStream => {
566 StreamWriteError::NotConnected
567 }
568 quinn::WriteError::ZeroRttRejected => StreamWriteError::QuicProto,
569 }
570 }
571}
572
573impl From<quinn::ReadError> for StreamReadError {
574 fn from(error: quinn::ReadError) -> Self {
575 match error {
576 quinn::ReadError::Reset(code) => StreamReadError::Reset(varint_q2w(code)),
577 quinn::ReadError::ConnectionLost(_) | quinn::ReadError::ClosedStream => {
578 StreamReadError::NotConnected
579 }
580 quinn::ReadError::IllegalOrderedRead => StreamReadError::QuicProto,
581 quinn::ReadError::ZeroRttRejected => StreamReadError::QuicProto,
582 }
583 }
584}
585
586pub mod connect;
587pub mod qpack;
588pub mod settings;