wtransport/
stream.rs

1use crate::driver::streams::bilocal::StreamBiLocalQuic;
2use crate::driver::streams::unilocal::StreamUniLocalQuic;
3use crate::driver::streams::ProtoWriteError;
4use crate::driver::streams::QuicRecvStream;
5use crate::driver::streams::QuicSendStream;
6use crate::error::ClosedStream;
7use crate::error::StreamOpeningError;
8use crate::error::StreamReadError;
9use crate::error::StreamReadExactError;
10use crate::error::StreamWriteError;
11use crate::SessionId;
12use crate::StreamId;
13use crate::VarInt;
14use std::future::Future;
15use std::pin::Pin;
16use std::task::Context;
17use std::task::Poll;
18use tokio::io::ReadBuf;
19use wtransport_proto::stream_header::StreamHeader;
20
21/// A stream that can only be used to send data.
22#[derive(Debug)]
23pub struct SendStream(QuicSendStream);
24
25impl SendStream {
26    #[inline(always)]
27    pub(crate) fn new(stream: QuicSendStream) -> Self {
28        Self(stream)
29    }
30
31    /// Writes bytes to the stream.
32    ///
33    /// On success, returns the number of bytes written.
34    /// Congestion and flow control may cause this to be shorter than `buf.len()`,
35    /// indicating that only a prefix of `buf` was written.
36    #[inline(always)]
37    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, StreamWriteError> {
38        self.0.write(buf).await
39    }
40
41    /// Convenience method to write an entire buffer to the stream.
42    #[inline(always)]
43    pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), StreamWriteError> {
44        self.0.write_all(buf).await
45    }
46
47    /// Shuts down the send stream gracefully.
48    ///
49    /// No new data may be written after calling this method. Completes when the peer has
50    /// acknowledged all sent data, retransmitting data as needed.
51    #[inline(always)]
52    pub async fn finish(&mut self) -> Result<(), StreamWriteError> {
53        self.0.finish().await
54    }
55
56    /// Returns the [`StreamId`] associated.
57    #[inline(always)]
58    pub fn id(&self) -> StreamId {
59        self.0.id()
60    }
61
62    /// Sets the priority of the send stream.
63    ///
64    /// Every send stream has an initial priority of 0. Locally buffered data from streams with
65    /// higher priority will be transmitted before data from streams with lower priority. Changing
66    /// the priority of a stream with pending data may only take effect after that data has been
67    /// transmitted. Using many different priority levels per connection may have a negative
68    /// impact on performance.
69    #[inline(always)]
70    pub fn set_priority(&self, priority: i32) {
71        self.0.set_priority(priority);
72    }
73
74    /// Gets the priority of the send stream.
75    ///
76    /// # Panics
77    ///
78    /// If `reset` was called.
79    #[inline(always)]
80    pub fn priority(&self) -> i32 {
81        self.0.priority()
82    }
83
84    /// Closes the send stream immediately.
85    ///
86    /// No new data can be written after calling this method. Locally buffered data is dropped, and
87    /// previously transmitted data will no longer be retransmitted if lost. If an attempt has
88    /// already been made to finish the stream, the peer may still receive all written data.
89    ///
90    /// If called more than once, subsequent calls will result in [`ClosedStream`] error.
91    #[inline(always)]
92    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
93        self.0.reset(error_code)
94    }
95
96    /// Passively waits for the send stream to be stopped for any reason.
97    ///
98    /// Returns [`StreamWriteError::Closed`] if the stream was already `finish`ed or `reset`.
99    ///
100    /// Otherwise returns [`StreamWriteError::Stopped`] with an error code from the peer.
101    #[inline(always)]
102    pub async fn stopped(&mut self) -> StreamWriteError {
103        self.0.stopped().await
104    }
105
106    /// Returns a reference to the underlying QUIC stream.
107    #[cfg(feature = "quinn")]
108    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
109    #[inline(always)]
110    pub fn quic_stream(&self) -> &quinn::SendStream {
111        self.0.quic_stream()
112    }
113
114    /// Returns a mutable reference to the underlying QUIC stream.
115    #[cfg(feature = "quinn")]
116    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
117    #[inline(always)]
118    pub fn quic_stream_mut(&mut self) -> &mut quinn::SendStream {
119        self.0.quic_stream_mut()
120    }
121}
122
123/// A stream that can only be used to receive data.
124#[derive(Debug)]
125pub struct RecvStream(QuicRecvStream);
126
127impl RecvStream {
128    #[inline(always)]
129    pub(crate) fn new(stream: QuicRecvStream) -> Self {
130        Self(stream)
131    }
132
133    /// Read data contiguously from the stream.
134    ///
135    /// On success, returns the number of bytes read into `buf`.
136    #[inline(always)]
137    pub async fn read(&mut self, buf: &mut [u8]) -> Result<Option<usize>, StreamReadError> {
138        self.0.read(buf).await
139    }
140
141    /// Reads an exact number of bytes contiguously from the stream.
142    ///
143    /// If the stream terminates before the entire length has been read, it
144    /// returns [`StreamReadExactError::FinishedEarly`].
145    pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), StreamReadExactError> {
146        self.0.read_exact(buf).await
147    }
148
149    /// Stops accepting data on the stream.
150    ///
151    /// Discards unread data and notifies the peer to stop transmitting.
152    pub fn stop(mut self, error_code: VarInt) {
153        let _ = self.0.stop(error_code);
154    }
155
156    /// Returns the [`StreamId`] associated.
157    #[inline(always)]
158    pub fn id(&self) -> StreamId {
159        self.0.id()
160    }
161
162    /// Returns a reference to the underlying QUIC stream.
163    #[cfg(feature = "quinn")]
164    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
165    #[inline(always)]
166    pub fn quic_stream(&self) -> &quinn::RecvStream {
167        self.0.quic_stream()
168    }
169
170    /// Returns a mutable reference to the underlying QUIC stream.
171    #[cfg(feature = "quinn")]
172    #[cfg_attr(docsrs, doc(cfg(feature = "quinn")))]
173    #[inline(always)]
174    pub fn quic_stream_mut(&mut self) -> &mut quinn::RecvStream {
175        self.0.quic_stream_mut()
176    }
177}
178
179/// A bidirectional stream composed of [`SendStream`] and [`RecvStream`].
180///
181/// `BiStream` is a utility particularly useful in situations where a generic
182/// function or method expects a single object that must implement both
183/// [`AsyncRead`](tokio::io::AsyncRead) and [`AsyncWrite`](tokio::io::AsyncWrite).
184///
185/// # Examples
186///
187/// ```
188/// use tokio::io::AsyncRead;
189/// use tokio::io::AsyncWrite;
190/// use wtransport::stream::BiStream;
191///
192/// async fn do_operation<T>(io: T)
193/// where
194///     T: AsyncRead + AsyncWrite,
195/// {
196///     // ...
197/// }
198///
199/// # use wtransport::Connection;
200/// # async fn run(connection: Connection) {
201/// let bi_stream = BiStream::join(connection.accept_bi().await.unwrap());
202/// do_operation(bi_stream).await;
203/// # }
204/// ```
205#[derive(Debug)]
206pub struct BiStream((SendStream, RecvStream));
207
208impl BiStream {
209    /// Joins a sending stream and a receiving stream into a single `BiStream` object.
210    pub fn join(s: (SendStream, RecvStream)) -> Self {
211        Self(s)
212    }
213
214    /// Splits the bidirectional stream into its sending and receiving stream handles.
215    pub fn split(self) -> (SendStream, RecvStream) {
216        self.0
217    }
218
219    /// Returns a reference to the inner [`SendStream`].
220    pub fn send(&self) -> &SendStream {
221        &self.0 .0
222    }
223
224    /// Returns a mutable reference to the inner [`SendStream`].
225    pub fn send_mut(&mut self) -> &mut SendStream {
226        &mut self.0 .0
227    }
228
229    /// Returns a reference to the inner [`RecvStream`].
230    pub fn recv(&self) -> &RecvStream {
231        &self.0 .1
232    }
233
234    /// Returns a mutable reference to the inner [`RecvStream`].
235    pub fn recv_mut(&mut self) -> &mut RecvStream {
236        &mut self.0 .1
237    }
238}
239
240impl From<(SendStream, RecvStream)> for BiStream {
241    fn from(value: (SendStream, RecvStream)) -> Self {
242        Self::join(value)
243    }
244}
245
246impl From<BiStream> for (SendStream, RecvStream) {
247    fn from(value: BiStream) -> Self {
248        value.split()
249    }
250}
251
252impl tokio::io::AsyncWrite for SendStream {
253    #[inline(always)]
254    fn poll_write(
255        mut self: Pin<&mut Self>,
256        cx: &mut Context<'_>,
257        buf: &[u8],
258    ) -> Poll<std::io::Result<usize>> {
259        tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
260    }
261
262    #[inline(always)]
263    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
264        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
265    }
266
267    #[inline(always)]
268    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
269        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
270    }
271
272    #[inline(always)]
273    fn poll_write_vectored(
274        mut self: Pin<&mut Self>,
275        cx: &mut Context<'_>,
276        bufs: &[std::io::IoSlice<'_>],
277    ) -> Poll<Result<usize, std::io::Error>> {
278        tokio::io::AsyncWrite::poll_write_vectored(Pin::new(&mut self.0), cx, bufs)
279    }
280
281    #[inline(always)]
282    fn is_write_vectored(&self) -> bool {
283        tokio::io::AsyncWrite::is_write_vectored(&self.0)
284    }
285}
286
287impl tokio::io::AsyncRead for RecvStream {
288    #[inline(always)]
289    fn poll_read(
290        mut self: Pin<&mut Self>,
291        cx: &mut Context<'_>,
292        buf: &mut ReadBuf<'_>,
293    ) -> Poll<std::io::Result<()>> {
294        tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
295    }
296}
297
298impl tokio::io::AsyncWrite for BiStream {
299    fn poll_write(
300        mut self: Pin<&mut Self>,
301        cx: &mut Context<'_>,
302        buf: &[u8],
303    ) -> Poll<Result<usize, std::io::Error>> {
304        tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0 .0), cx, buf)
305    }
306
307    fn poll_flush(
308        mut self: Pin<&mut Self>,
309        cx: &mut Context<'_>,
310    ) -> Poll<Result<(), std::io::Error>> {
311        tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0 .0), cx)
312    }
313
314    fn poll_shutdown(
315        mut self: Pin<&mut Self>,
316        cx: &mut Context<'_>,
317    ) -> Poll<Result<(), std::io::Error>> {
318        tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0 .0), cx)
319    }
320}
321
322impl tokio::io::AsyncRead for BiStream {
323    fn poll_read(
324        mut self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326        buf: &mut ReadBuf<'_>,
327    ) -> Poll<std::io::Result<()>> {
328        tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0 .1), cx, buf)
329    }
330}
331
332type DynFutureUniStream = dyn Future<Output = Result<SendStream, StreamOpeningError>> + Send + Sync;
333
334/// [`Future`] for an in-progress opening unidirectional stream.
335///
336/// See [`Connection::open_uni`](crate::Connection::open_uni).
337pub struct OpeningUniStream(Pin<Box<DynFutureUniStream>>);
338
339impl OpeningUniStream {
340    pub(crate) fn new(session_id: SessionId, quic_stream: StreamUniLocalQuic) -> Self {
341        Self(Box::pin(async move {
342            match quic_stream
343                .upgrade(StreamHeader::new_webtransport(session_id))
344                .await
345            {
346                Ok(stream) => Ok(SendStream(stream.upgrade().into_stream())),
347                Err(ProtoWriteError::NotConnected) => Err(StreamOpeningError::NotConnected),
348                Err(ProtoWriteError::Stopped) => Err(StreamOpeningError::Refused),
349            }
350        }))
351    }
352}
353
354impl Future for OpeningUniStream {
355    type Output = Result<SendStream, StreamOpeningError>;
356
357    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
358        Future::poll(self.0.as_mut(), cx)
359    }
360}
361
362type DynFutureBiStream =
363    dyn Future<Output = Result<(SendStream, RecvStream), StreamOpeningError>> + Send + Sync;
364
365/// [`Future`] for an in-progress opening bidirectional stream.
366///
367/// See [`Connection::open_bi`](crate::Connection::open_bi).
368pub struct OpeningBiStream(Pin<Box<DynFutureBiStream>>);
369
370impl OpeningBiStream {
371    pub(crate) fn new(session_id: SessionId, quic_stream: StreamBiLocalQuic) -> Self {
372        Self(Box::pin(async move {
373            match quic_stream.upgrade().upgrade(session_id).await {
374                Ok(stream) => {
375                    let stream = stream.into_stream();
376                    Ok((SendStream::new(stream.0), RecvStream::new(stream.1)))
377                }
378                Err(ProtoWriteError::NotConnected) => Err(StreamOpeningError::NotConnected),
379                Err(ProtoWriteError::Stopped) => Err(StreamOpeningError::Refused),
380            }
381        }))
382    }
383}
384
385impl Future for OpeningBiStream {
386    type Output = Result<(SendStream, RecvStream), StreamOpeningError>;
387
388    #[inline(always)]
389    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390        Future::poll(self.0.as_mut(), cx)
391    }
392}