wrpc_runtime_wasmtime/rpc/
mod.rs

//! `wrpc:rpc` implementation
2
3use core::any::Any;
4use core::fmt;
5use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll};
9
10use std::sync::Arc;
11
12use anyhow::Context as _;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use wasmtime::component::{HasData, Linker};
15use wasmtime_wasi::p2::Pollable;
16use wrpc_transport::Invoke;
17
18use crate::{bindings, WrpcView};
19
20mod host;
21
/// Transparent newtype around a host value.
///
/// [`add_to_linker`] wraps the store data in this type when linking the
/// `wrpc:rpc` interfaces. `#[repr(transparent)]` guarantees the wrapper has
/// the same layout as the wrapped `T`.
#[repr(transparent)]
pub struct WrpcRpcImpl<T>(
    /// The wrapped host value.
    pub T,
);
25
26impl<T: 'static> HasData for WrpcRpcImpl<T> {
27    type Data<'a> = WrpcRpcImpl<&'a mut T>;
28}
29
30pub fn add_to_linker<T>(linker: &mut Linker<T>) -> anyhow::Result<()>
31where
32    T: WrpcView,
33    T::Invoke: Clone + 'static,
34    <T::Invoke as Invoke>::Context: 'static,
35{
36    bindings::rpc::context::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
37        .context("failed to link `wrpc:rpc/context`")?;
38    bindings::rpc::error::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
39        .context("failed to link `wrpc:rpc/error`")?;
40    bindings::rpc::invoker::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
41        .context("failed to link `wrpc:rpc/invoker`")?;
42    bindings::rpc::transport::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
43        .context("failed to link `wrpc:rpc/transport`")?;
44    Ok(())
45}
46
47/// RPC error
48pub enum Error {
49    /// Error originating from [Invoke::invoke] call
50    Invoke(anyhow::Error),
51    /// Error originating from [Index::index](wrpc_transport::Index::index) call on [Invoke::Incoming].
52    IncomingIndex(anyhow::Error),
53    /// Error originating from [Index::index](wrpc_transport::Index::index) call on
54    /// [Invoke::Outgoing].
55    OutgoingIndex(anyhow::Error),
56    /// Error originating from a `wasi:io` stream provided by this crate.
57    Stream(StreamError),
58}
59
60impl fmt::Debug for Error {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        match self {
63            Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
64                error.fmt(f)
65            }
66            Error::Stream(error) => error.fmt(f),
67        }
68    }
69}
70
71impl fmt::Display for Error {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        match self {
74            Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
75                error.fmt(f)
76            }
77            Error::Stream(error) => error.fmt(f),
78        }
79    }
80}
81
/// Failure of one of the `wasi:io` streams provided by this crate.
pub enum StreamError {
    /// The lock guarding the shared channel was poisoned.
    LockPoisoned,
    /// The type-erased channel did not hold the expected concrete type; the
    /// payload is a static description of the mismatch.
    TypeMismatch(&'static str),
    /// The underlying stream failed while reading.
    Read(std::io::Error),
    /// The underlying stream failed while writing.
    Write(std::io::Error),
    /// The underlying stream failed while flushing.
    Flush(std::io::Error),
    /// The underlying stream failed while shutting down.
    Shutdown(std::io::Error),
}
91
92impl core::error::Error for StreamError {}
93
94impl fmt::Debug for StreamError {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        match self {
97            StreamError::LockPoisoned => "lock poisoned".fmt(f),
98            StreamError::TypeMismatch(error) => error.fmt(f),
99            StreamError::Read(error)
100            | StreamError::Write(error)
101            | StreamError::Flush(error)
102            | StreamError::Shutdown(error) => error.fmt(f),
103        }
104    }
105}
106
107impl fmt::Display for StreamError {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            StreamError::LockPoisoned => "lock poisoned".fmt(f),
111            StreamError::TypeMismatch(error) => error.fmt(f),
112            StreamError::Read(error)
113            | StreamError::Write(error)
114            | StreamError::Flush(error)
115            | StreamError::Shutdown(error) => error.fmt(f),
116        }
117    }
118}
119
/// A type-erased value that is either still being computed or already
/// available.
pub enum Invocation {
    /// A pending computation that resolves to the type-erased value.
    Future(Pin<Box<dyn Future<Output = Box<dyn Any + Send>> + Send>>),
    /// The resolved, type-erased value.
    Ready(Box<dyn Any + Send>),
}
124
125#[wasmtime_wasi::async_trait]
126impl Pollable for Invocation {
127    async fn ready(&mut self) {
128        match self {
129            Self::Future(fut) => {
130                let res = fut.await;
131                *self = Self::Ready(res);
132            }
133            Self::Ready(..) => {}
134        }
135    }
136}
137
/// Shared, lock-protected, type-erased outgoing channel.
///
/// [`OutgoingChannelStream`] downcasts the boxed value to its concrete writer
/// type before use.
pub struct OutgoingChannel(
    /// The shared, lock-guarded boxed value.
    pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>,
);
139
/// Shared, lock-protected, type-erased incoming channel.
///
/// [`IncomingChannelStream`] downcasts the boxed value to its concrete reader
/// type before use.
pub struct IncomingChannel(
    /// The shared, lock-guarded boxed value.
    pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>,
);
141
142pub struct IncomingChannelStream<T> {
143    incoming: IncomingChannel,
144    _ty: PhantomData<T>,
145}
146
147impl<T: AsyncRead + Unpin + 'static> AsyncRead for IncomingChannelStream<T> {
148    fn poll_read(
149        self: Pin<&mut Self>,
150        cx: &mut Context<'_>,
151        buf: &mut ReadBuf<'_>,
152    ) -> Poll<std::io::Result<()>> {
153        let Ok(mut incoming) = self.incoming.0.write() else {
154            return Poll::Ready(Err(std::io::Error::new(
155                std::io::ErrorKind::Deadlock,
156                StreamError::LockPoisoned,
157            )));
158        };
159        let Some(incoming) = incoming.downcast_mut::<T>() else {
160            return Poll::Ready(Err(std::io::Error::new(
161                std::io::ErrorKind::InvalidData,
162                StreamError::TypeMismatch("invalid incoming channel type"),
163            )));
164        };
165        Pin::new(incoming)
166            .poll_read(cx, buf)
167            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Read(err)))
168    }
169}
170
171pub struct OutgoingChannelStream<T> {
172    outgoing: OutgoingChannel,
173    _ty: PhantomData<T>,
174}
175
176impl<T: AsyncWrite + Unpin + 'static> AsyncWrite for OutgoingChannelStream<T> {
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<Result<usize, std::io::Error>> {
182        let Ok(mut outgoing) = self.outgoing.0.write() else {
183            return Poll::Ready(Err(std::io::Error::new(
184                std::io::ErrorKind::Deadlock,
185                StreamError::LockPoisoned,
186            )));
187        };
188        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
189            return Poll::Ready(Err(std::io::Error::new(
190                std::io::ErrorKind::InvalidData,
191                StreamError::TypeMismatch("invalid outgoing channel type"),
192            )));
193        };
194        Pin::new(outgoing)
195            .poll_write(cx, buf)
196            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Write(err)))
197    }
198
199    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
200        let Ok(mut outgoing) = self.outgoing.0.write() else {
201            return Poll::Ready(Err(std::io::Error::new(
202                std::io::ErrorKind::Deadlock,
203                StreamError::LockPoisoned,
204            )));
205        };
206        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
207            return Poll::Ready(Err(std::io::Error::new(
208                std::io::ErrorKind::InvalidData,
209                StreamError::TypeMismatch("invalid outgoing channel type"),
210            )));
211        };
212        Pin::new(outgoing)
213            .poll_flush(cx)
214            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Flush(err)))
215    }
216
217    fn poll_shutdown(
218        self: Pin<&mut Self>,
219        cx: &mut Context<'_>,
220    ) -> Poll<Result<(), std::io::Error>> {
221        let Ok(mut outgoing) = self.outgoing.0.write() else {
222            return Poll::Ready(Err(std::io::Error::new(
223                std::io::ErrorKind::Deadlock,
224                StreamError::LockPoisoned,
225            )));
226        };
227        let Some(outgoing) = outgoing.downcast_mut::<T>() else {
228            return Poll::Ready(Err(std::io::Error::new(
229                std::io::ErrorKind::InvalidData,
230                StreamError::TypeMismatch("invalid outgoing channel type"),
231            )));
232        };
233        Pin::new(outgoing)
234            .poll_shutdown(cx)
235            .map_err(|err| std::io::Error::new(err.kind(), StreamError::Shutdown(err)))
236    }
237}