wrpc_runtime_wasmtime/rpc/
mod.rs1use core::any::Any;
4use core::fmt;
5use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll};
9
10use std::sync::Arc;
11
12use anyhow::Context as _;
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14use wasmtime::component::{HasData, Linker};
15use wasmtime_wasi::p2::Pollable;
16use wrpc_transport::Invoke;
17
18use crate::{bindings, WrpcView};
19
20mod host;
21
22#[repr(transparent)]
24pub struct WrpcRpcImpl<T>(pub T);
25
26impl<T: 'static> HasData for WrpcRpcImpl<T> {
27 type Data<'a> = WrpcRpcImpl<&'a mut T>;
28}
29
30pub fn add_to_linker<T>(linker: &mut Linker<T>) -> anyhow::Result<()>
31where
32 T: WrpcView,
33 T::Invoke: Clone + 'static,
34 <T::Invoke as Invoke>::Context: 'static,
35{
36 bindings::rpc::context::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
37 .context("failed to link `wrpc:rpc/context`")?;
38 bindings::rpc::error::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
39 .context("failed to link `wrpc:rpc/error`")?;
40 bindings::rpc::invoker::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
41 .context("failed to link `wrpc:rpc/invoker`")?;
42 bindings::rpc::transport::add_to_linker::<_, WrpcRpcImpl<T>>(linker, |t| WrpcRpcImpl(t))
43 .context("failed to link `wrpc:rpc/transport`")?;
44 Ok(())
45}
46
47pub enum Error {
49 Invoke(anyhow::Error),
51 IncomingIndex(anyhow::Error),
53 OutgoingIndex(anyhow::Error),
56 Stream(StreamError),
58}
59
60impl fmt::Debug for Error {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 match self {
63 Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
64 error.fmt(f)
65 }
66 Error::Stream(error) => error.fmt(f),
67 }
68 }
69}
70
71impl fmt::Display for Error {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Error::Invoke(error) | Error::IncomingIndex(error) | Error::OutgoingIndex(error) => {
75 error.fmt(f)
76 }
77 Error::Stream(error) => error.fmt(f),
78 }
79 }
80}
81
/// Failure raised by the channel stream adapters.
///
/// Both [`fmt::Debug`] and [`fmt::Display`] format only the payload. For
/// [`StreamError::LockPoisoned`] the payload is the fixed text `lock poisoned`.
/// The variant name itself is never printed.
pub enum StreamError {
    /// The lock guarding the channel could not be acquired because it is poisoned.
    LockPoisoned,
    /// The channel's type-erased value was not of the expected type; carries a
    /// description of the mismatch.
    TypeMismatch(&'static str),
    /// Reading from the underlying stream failed.
    Read(std::io::Error),
    /// Writing to the underlying stream failed.
    Write(std::io::Error),
    /// Flushing the underlying stream failed.
    Flush(std::io::Error),
    /// Shutting down the underlying stream failed.
    Shutdown(std::io::Error),
}

impl core::error::Error for StreamError {}

impl fmt::Debug for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick what to format, then format it exactly once.
        let target: &dyn fmt::Debug = match self {
            Self::LockPoisoned => &"lock poisoned",
            Self::TypeMismatch(message) => message,
            Self::Read(io) | Self::Write(io) | Self::Flush(io) | Self::Shutdown(io) => io,
        };
        fmt::Debug::fmt(target, f)
    }
}

impl fmt::Display for StreamError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick what to format, then format it exactly once.
        let target: &dyn fmt::Display = match self {
            Self::LockPoisoned => &"lock poisoned",
            Self::TypeMismatch(message) => message,
            Self::Read(io) | Self::Write(io) | Self::Flush(io) | Self::Shutdown(io) => io,
        };
        fmt::Display::fmt(target, f)
    }
}

120pub enum Invocation {
121 Future(Pin<Box<dyn Future<Output = Box<dyn Any + Send>> + Send>>),
122 Ready(Box<dyn Any + Send>),
123}
124
125#[wasmtime_wasi::async_trait]
126impl Pollable for Invocation {
127 async fn ready(&mut self) {
128 match self {
129 Self::Future(fut) => {
130 let res = fut.await;
131 *self = Self::Ready(res);
132 }
133 Self::Ready(..) => {}
134 }
135 }
136}
137
/// Shared, lock-protected, type-erased outgoing channel.
///
/// Writers recover the concrete stream type by downcasting (see
/// `OutgoingChannelStream`).
pub struct OutgoingChannel(
    /// The type-erased stream, shared through an `Arc` and guarded by an `RwLock`.
    pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>,
);

/// Shared, lock-protected, type-erased incoming channel.
///
/// Readers recover the concrete stream type by downcasting (see
/// `IncomingChannelStream`).
pub struct IncomingChannel(
    /// The type-erased stream, shared through an `Arc` and guarded by an `RwLock`.
    pub Arc<std::sync::RwLock<Box<dyn Any + Send + Sync>>>,
);

142pub struct IncomingChannelStream<T> {
143 incoming: IncomingChannel,
144 _ty: PhantomData<T>,
145}
146
147impl<T: AsyncRead + Unpin + 'static> AsyncRead for IncomingChannelStream<T> {
148 fn poll_read(
149 self: Pin<&mut Self>,
150 cx: &mut Context<'_>,
151 buf: &mut ReadBuf<'_>,
152 ) -> Poll<std::io::Result<()>> {
153 let Ok(mut incoming) = self.incoming.0.write() else {
154 return Poll::Ready(Err(std::io::Error::new(
155 std::io::ErrorKind::Deadlock,
156 StreamError::LockPoisoned,
157 )));
158 };
159 let Some(incoming) = incoming.downcast_mut::<T>() else {
160 return Poll::Ready(Err(std::io::Error::new(
161 std::io::ErrorKind::InvalidData,
162 StreamError::TypeMismatch("invalid incoming channel type"),
163 )));
164 };
165 Pin::new(incoming)
166 .poll_read(cx, buf)
167 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Read(err)))
168 }
169}
170
171pub struct OutgoingChannelStream<T> {
172 outgoing: OutgoingChannel,
173 _ty: PhantomData<T>,
174}
175
176impl<T: AsyncWrite + Unpin + 'static> AsyncWrite for OutgoingChannelStream<T> {
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<Result<usize, std::io::Error>> {
182 let Ok(mut outgoing) = self.outgoing.0.write() else {
183 return Poll::Ready(Err(std::io::Error::new(
184 std::io::ErrorKind::Deadlock,
185 StreamError::LockPoisoned,
186 )));
187 };
188 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
189 return Poll::Ready(Err(std::io::Error::new(
190 std::io::ErrorKind::InvalidData,
191 StreamError::TypeMismatch("invalid outgoing channel type"),
192 )));
193 };
194 Pin::new(outgoing)
195 .poll_write(cx, buf)
196 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Write(err)))
197 }
198
199 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
200 let Ok(mut outgoing) = self.outgoing.0.write() else {
201 return Poll::Ready(Err(std::io::Error::new(
202 std::io::ErrorKind::Deadlock,
203 StreamError::LockPoisoned,
204 )));
205 };
206 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
207 return Poll::Ready(Err(std::io::Error::new(
208 std::io::ErrorKind::InvalidData,
209 StreamError::TypeMismatch("invalid outgoing channel type"),
210 )));
211 };
212 Pin::new(outgoing)
213 .poll_flush(cx)
214 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Flush(err)))
215 }
216
217 fn poll_shutdown(
218 self: Pin<&mut Self>,
219 cx: &mut Context<'_>,
220 ) -> Poll<Result<(), std::io::Error>> {
221 let Ok(mut outgoing) = self.outgoing.0.write() else {
222 return Poll::Ready(Err(std::io::Error::new(
223 std::io::ErrorKind::Deadlock,
224 StreamError::LockPoisoned,
225 )));
226 };
227 let Some(outgoing) = outgoing.downcast_mut::<T>() else {
228 return Poll::Ready(Err(std::io::Error::new(
229 std::io::ErrorKind::InvalidData,
230 StreamError::TypeMismatch("invalid outgoing channel type"),
231 )));
232 };
233 Pin::new(outgoing)
234 .poll_shutdown(cx)
235 .map_err(|err| std::io::Error::new(err.kind(), StreamError::Shutdown(err)))
236 }
237}