wrpc_transport_nats/
lib.rs

1//! wRPC NATS.io transport
2
3#![allow(clippy::type_complexity)]
4
5use core::future::Future;
6use core::iter::zip;
7use core::ops::{Deref, DerefMut};
8use core::pin::{pin, Pin};
9use core::task::{ready, Context, Poll};
10use core::{mem, str};
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use anyhow::{anyhow, ensure, Context as _};
16use async_nats::message::OutboundMessage;
17use async_nats::{HeaderMap, ServerInfo, StatusCode, Subject};
18use bytes::{Buf as _, Bytes};
19use futures::sink::SinkExt as _;
20use futures::{Stream, StreamExt};
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::select;
23use tokio::sync::{mpsc, oneshot};
24use tokio::task::JoinSet;
25use tokio_stream::wrappers::ReceiverStream;
26use tracing::{debug, error, instrument, trace, warn};
27use wrpc_transport::Index as _;
28
/// wRPC protocol identifier, embedded as a segment of every invocation subject.
pub const PROTOCOL: &str = "wrpc.0.0.1";
30
31fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
32    match tokio::runtime::Handle::try_current() {
33        Ok(rt) => {
34            rt.spawn(fut);
35        }
36        Err(_) => match tokio::runtime::Runtime::new() {
37            Ok(rt) => {
38                rt.spawn(fut);
39            }
40            Err(err) => error!(?err, "failed to create a new Tokio runtime"),
41        },
42    }
43}
44
45fn new_inbox(inbox: &str) -> String {
46    let id = nuid::next();
47    let mut s = String::with_capacity(inbox.len().saturating_add(id.len()));
48    s.push_str(inbox);
49    s.push_str(&id);
50    s
51}
52
/// Returns the subject that parameters are published on for the given `prefix`,
/// i.e. `<prefix>.params`.
#[must_use]
#[inline]
pub fn param_subject(prefix: &str) -> String {
    [prefix, ".params"].concat()
}
58
/// Returns the subject that results are published on for the given `prefix`,
/// i.e. `<prefix>.results`.
#[must_use]
#[inline]
pub fn result_subject(prefix: &str) -> String {
    [prefix, ".results"].concat()
}
64
/// Builds a subject by appending each index in `path` to `prefix`, separated by `.`.
///
/// An empty `prefix` produces no leading separator. For example, `("a", [1, 2])` becomes
/// `a.1.2` and `("", [3])` becomes `3`.
#[must_use]
#[inline]
pub fn index_path(prefix: &str, path: &[usize]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    for idx in path {
        if !subject.is_empty() {
            subject.push('.');
        }
        subject.push_str(&idx.to_string());
    }
    subject
}
80
/// Builds a subscription subject from `prefix` and `path`, separated by `.`.
///
/// `Some(i)` segments are rendered as the index and `None` segments as the single-token NATS
/// wildcard `*`. An empty `prefix` produces no leading separator.
#[must_use]
#[inline]
pub fn subscribe_path(prefix: &str, path: &[Option<usize>]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    for segment in path {
        if !subject.is_empty() {
            subject.push('.');
        }
        match segment {
            Some(idx) => subject.push_str(&idx.to_string()),
            None => subject.push('*'),
        }
    }
    subject
}
100
101#[must_use]
102#[inline]
103pub fn invocation_subject(prefix: &str, instance: &str, func: &str) -> String {
104    let mut s =
105        String::with_capacity(prefix.len() + PROTOCOL.len() + instance.len() + func.len() + 3);
106    if !prefix.is_empty() {
107        s.push_str(prefix);
108        s.push('.');
109    }
110    s.push_str(PROTOCOL);
111    s.push('.');
112    if !instance.is_empty() {
113        s.push_str(instance);
114        s.push('.');
115    }
116    s.push_str(func);
117    s
118}
119
/// Error reported when a state machine is found in its placeholder `Corrupted` state.
fn corrupted_memory_error() -> std::io::Error {
    const MESSAGE: &str = "corrupted memory state";
    std::io::Error::other(MESSAGE)
}
123
124/// Transport subscriber
125pub struct Subscriber {
126    rx: ReceiverStream<Message>,
127    subject: Subject,
128    commands: mpsc::Sender<Command>,
129    tasks: Arc<JoinSet<()>>,
130}
131
132impl Drop for Subscriber {
133    fn drop(&mut self) {
134        let commands = self.commands.clone();
135        let subject = mem::replace(&mut self.subject, Subject::from_static(""));
136        let tasks = Arc::clone(&self.tasks);
137        spawn_async(async move {
138            trace!(?subject, "shutting down subscriber");
139            if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
140                warn!(?err, "failed to shutdown subscriber");
141            }
142            drop(tasks);
143        });
144    }
145}
146
147impl Deref for Subscriber {
148    type Target = ReceiverStream<Message>;
149
150    fn deref(&self) -> &Self::Target {
151        &self.rx
152    }
153}
154
155impl DerefMut for Subscriber {
156    fn deref_mut(&mut self) -> &mut Self::Target {
157        &mut self.rx
158    }
159}
160
161enum Command {
162    Subscribe(Subject, mpsc::Sender<Message>),
163    Unsubscribe(Subject),
164    Batch(Box<[Command]>),
165}
166
167/// Subset of [`async_nats::Message`](async_nats::Message) used by this crate
168pub struct Message {
169    reply: Option<Subject>,
170    payload: Bytes,
171    status: Option<async_nats::StatusCode>,
172    description: Option<String>,
173}
174
175#[derive(Clone, Debug)]
176pub struct Client {
177    nats: Arc<async_nats::Client>,
178    prefix: Arc<str>,
179    inbox: Arc<str>,
180    queue_group: Option<Arc<str>>,
181    commands: mpsc::Sender<Command>,
182    tasks: Arc<JoinSet<()>>,
183}
184
185impl Client {
186    pub async fn new(
187        nats: impl Into<Arc<async_nats::Client>>,
188        prefix: impl Into<Arc<str>>,
189        queue_group: Option<Arc<str>>,
190    ) -> anyhow::Result<Self> {
191        let nats = nats.into();
192        let mut inbox = nats.new_inbox();
193        inbox.push('.');
194        let mut subject = String::with_capacity(inbox.len().saturating_add(1));
195        subject.push_str(&inbox);
196        subject.push('>');
197        let mut sub = nats
198            .subscribe(Subject::from(subject))
199            .await
200            .context("failed to subscribe on an inbox subject")?;
201
202        let mut tasks = JoinSet::new();
203        let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
204        tasks.spawn({
205            async move {
206                fn handle_command(subs: &mut HashMap<String, mpsc::Sender<Message>>, cmd: Command) {
207                    match cmd {
208                        Command::Subscribe(s, tx) => {
209                            subs.insert(s.into_string(), tx);
210                        }
211                        Command::Unsubscribe(s) => {
212                            subs.remove(s.as_str());
213                        }
214                        Command::Batch(cmds) => {
215                            for cmd in cmds {
216                                handle_command(subs, cmd);
217                            }
218                        }
219                    }
220                }
221                async fn handle_message(
222                    subs: &mut HashMap<String, mpsc::Sender<Message>>,
223                    async_nats::Message {
224                        subject,
225                        reply,
226                        payload,
227                        status,
228                        description,
229                        ..
230                    }: async_nats::Message,
231                ) {
232                    let Some(sub) = subs.get_mut(subject.as_str()) else {
233                        debug!(?subject, "drop message with no subscriber");
234                        return;
235                    };
236                    let Ok(sub) = sub.reserve().await else {
237                        debug!(?subject, "drop message with closed subscriber");
238                        subs.remove(subject.as_str());
239                        return;
240                    };
241                    sub.send(Message {
242                        reply,
243                        payload,
244                        status,
245                        description,
246                    });
247                }
248
249                let mut subs = HashMap::new();
250                loop {
251                    select! {
252                        Some(msg) = sub.next() => handle_message(&mut subs, msg).await,
253                        Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
254                        else => return,
255                    }
256                }
257            }
258        });
259        Ok(Self {
260            nats,
261            prefix: prefix.into(),
262            inbox: inbox.into(),
263            queue_group,
264            commands: cmd_tx,
265            tasks: Arc::new(tasks),
266        })
267    }
268}
269
270pub struct ByteSubscription(Subscriber);
271
272impl Stream for ByteSubscription {
273    type Item = std::io::Result<Bytes>;
274
275    #[instrument(level = "trace", skip_all)]
276    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277        match self.0.poll_next_unpin(cx) {
278            Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
279            Poll::Ready(None) => Poll::Ready(None),
280            Poll::Pending => Poll::Pending,
281        }
282    }
283}
284
285#[derive(Default)]
286enum IndexTrie {
287    #[default]
288    Empty,
289    Leaf(Subscriber),
290    IndexNode {
291        subscriber: Option<Subscriber>,
292        nested: Vec<Option<IndexTrie>>,
293    },
294    WildcardNode {
295        subscriber: Option<Subscriber>,
296        nested: Option<Box<IndexTrie>>,
297    },
298}
299
300impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
301    fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
302        match path {
303            [] => Self::Leaf(sub),
304            [None, path @ ..] => Self::WildcardNode {
305                subscriber: None,
306                nested: Some(Box::new(Self::from((path, sub)))),
307            },
308            [Some(i), path @ ..] => Self::IndexNode {
309                subscriber: None,
310                nested: {
311                    let n = i.saturating_add(1);
312                    let mut nested = Vec::with_capacity(n);
313                    nested.resize_with(n, Option::default);
314                    nested[*i] = Some(Self::from((path, sub)));
315                    nested
316                },
317            },
318        }
319    }
320}
321
322impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
323    fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
324        let mut root = Self::Empty;
325        for (path, sub) in iter {
326            if !root.insert(path.as_ref(), sub) {
327                return Self::Empty;
328            }
329        }
330        root
331    }
332}
333
impl IndexTrie {
    /// Returns `true` if this trie is the [`IndexTrie::Empty`] variant.
    #[inline]
    fn is_empty(&self) -> bool {
        matches!(self, IndexTrie::Empty)
    }

    /// Removes and returns the subscriber registered for exactly `path`, if any.
    ///
    /// Only [`IndexTrie::IndexNode`]s are traversed. Wildcard nodes are not demultiplexed yet
    /// (see the TODOs below), so any path that ends at or passes through a
    /// [`IndexTrie::WildcardNode`] yields `None`. A non-empty `path` reaching a `Leaf` or
    /// `Empty` node also yields `None` and leaves the trie unchanged.
    ///
    /// Taking the empty path at a `WildcardNode` removes that node from the trie. This drops
    /// every subscriber nested under it.
    ///
    /// When the subscriber of an `IndexNode` is taken, the node is kept (without a subscriber)
    /// if its `nested` vector is non-empty. Otherwise it is replaced by `Empty`.
    #[instrument(level = "trace", skip_all)]
    fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
        let Some((i, path)) = path.split_first() else {
            // End of the path: detach this node, then restore whatever must stay in the trie
            return match mem::take(self) {
                // TODO: Demux the subscription
                //IndexTrie::WildcardNode { subscriber, nested } => {
                //    if let Some(nested) = nested {
                //        *self = IndexTrie::WildcardNode {
                //            subscriber: None,
                //            nested: Some(nested),
                //        }
                //    }
                //    subscriber
                //}
                IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
                IndexTrie::Leaf(subscriber) => Some(subscriber),
                IndexTrie::IndexNode { subscriber, nested } => {
                    // Keep the children reachable for later `take` calls
                    if !nested.is_empty() {
                        *self = IndexTrie::IndexNode {
                            subscriber: None,
                            nested,
                        }
                    }
                    subscriber
                }
            };
        };
        match self {
            // TODO: Demux the subscription
            //Self::WildcardNode { ref mut nested, .. } => {
            //    nested.as_mut().and_then(|nested| nested.take(path))
            //}
            Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
            Self::IndexNode { ref mut nested, .. } => nested
                .get_mut(*i)
                .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
        }
    }

    /// Inserts `sub` under a `path` - returns `false` if it failed and `true` if it succeeded.
    ///
    /// Insertion fails if a subscriber is already registered for `path`. It also fails if
    /// `path` uses a wildcard where an existing node is indexed by concrete indices, or vice
    /// versa.
    ///
    /// Tree state after `false` is returned is undefined.
    #[instrument(level = "trace", skip_all)]
    fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
        match self {
            Self::Empty => {
                *self = Self::from((path, sub));
                true
            }
            Self::Leaf(..) => {
                // A subscriber is already registered for the empty path
                let Some((i, path)) = path.split_first() else {
                    return false;
                };
                // Cannot fail: `self` was matched as `Leaf` above
                let Self::Leaf(subscriber) = mem::take(self) else {
                    return false;
                };
                // Turn the leaf into an inner node that keeps its subscriber
                if let Some(i) = i {
                    let n = i.saturating_add(1);
                    let mut nested = Vec::with_capacity(n);
                    nested.resize_with(n, Option::default);
                    nested[*i] = Some(Self::from((path, sub)));
                    *self = Self::IndexNode {
                        subscriber: Some(subscriber),
                        nested,
                    };
                } else {
                    *self = Self::WildcardNode {
                        subscriber: Some(subscriber),
                        nested: Some(Box::new(Self::from((path, sub)))),
                    };
                }
                true
            }
            Self::WildcardNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [None, path @ ..]) => {
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Box::new(Self::from((path, sub))));
                        true
                    }
                }
                // Duplicate subscriber, or a concrete index below a wildcard node
                _ => false,
            },
            Self::IndexNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [Some(i), path @ ..]) => {
                    // Grow `nested` so that index `i` is addressable
                    let cap = i.saturating_add(1);
                    if nested.len() < cap {
                        nested.resize_with(cap, Option::default);
                    }
                    let nested = &mut nested[*i];
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Self::from((path, sub)));
                        true
                    }
                }
                // Duplicate subscriber, or a wildcard below a concrete-index node
                _ => false,
            },
        }
    }
}
456
457pub struct Reader {
458    buffer: Bytes,
459    incoming: Option<Subscriber>,
460    nested: Arc<std::sync::Mutex<IndexTrie>>,
461    path: Box<[usize]>,
462}
463
464impl wrpc_transport::Index<Self> for Reader {
465    #[instrument(level = "trace", skip(self))]
466    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
467        ensure!(!path.is_empty());
468        trace!("locking index tree");
469        let mut nested = self
470            .nested
471            .lock()
472            .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
473        trace!("taking index subscription");
474        let mut p = self.path.to_vec();
475        p.extend_from_slice(path);
476        let incoming = nested.take(&p);
477        Ok(Self {
478            buffer: Bytes::default(),
479            incoming,
480            nested: Arc::clone(&self.nested),
481            path: p.into_boxed_slice(),
482        })
483    }
484}
485
486impl AsyncRead for Reader {
487    #[instrument(level = "trace", skip_all, ret)]
488    fn poll_read(
489        mut self: Pin<&mut Self>,
490        cx: &mut Context<'_>,
491        buf: &mut ReadBuf<'_>,
492    ) -> Poll<std::io::Result<()>> {
493        let cap = buf.remaining();
494        if cap == 0 {
495            trace!("attempt to read empty buffer");
496            return Poll::Ready(Ok(()));
497        }
498
499        if !self.buffer.is_empty() {
500            if self.buffer.len() > cap {
501                trace!(cap, len = self.buffer.len(), "reading part of buffer");
502                buf.put_slice(&self.buffer.split_to(cap));
503            } else {
504                trace!(cap, len = self.buffer.len(), "reading full buffer");
505                buf.put_slice(&mem::take(&mut self.buffer));
506            }
507            return Poll::Ready(Ok(()));
508        }
509        let Some(incoming) = self.incoming.as_mut() else {
510            return Poll::Ready(Err(std::io::Error::new(
511                std::io::ErrorKind::NotFound,
512                format!("subscription not found for path {:?}", self.path),
513            )));
514        };
515        trace!("polling for next message");
516        match incoming.poll_next_unpin(cx) {
517            Poll::Ready(Some(Message { mut payload, .. })) => {
518                trace!(?payload, "received message");
519                if payload.is_empty() {
520                    trace!("received stream shutdown message");
521                    return Poll::Ready(Ok(()));
522                }
523                if payload.len() > cap {
524                    trace!(len = payload.len(), cap, "partially reading the message");
525                    buf.put_slice(&payload.split_to(cap));
526                    self.buffer = payload;
527                } else {
528                    trace!(len = payload.len(), cap, "filling the buffer with payload");
529                    buf.put_slice(&payload);
530                }
531                Poll::Ready(Ok(()))
532            }
533            Poll::Ready(None) => {
534                trace!("subscription finished");
535                Poll::Ready(Ok(()))
536            }
537            Poll::Pending => Poll::Pending,
538        }
539    }
540}
541
542#[derive(Clone, Debug)]
543pub struct SubjectWriter {
544    nats: async_nats::Client,
545    tx: Subject,
546    shutdown: bool,
547    tasks: Arc<JoinSet<()>>,
548}
549
550impl SubjectWriter {
551    fn new(nats: async_nats::Client, tx: Subject, tasks: Arc<JoinSet<()>>) -> Self {
552        Self {
553            nats,
554            tx,
555            shutdown: false,
556            tasks,
557        }
558    }
559}
560
561impl wrpc_transport::Index<Self> for SubjectWriter {
562    #[instrument(level = "trace", skip(self))]
563    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
564        ensure!(!path.is_empty());
565        let tx = Subject::from(index_path(self.tx.as_str(), path));
566        Ok(Self {
567            nats: self.nats.clone(),
568            tx,
569            shutdown: false,
570            tasks: Arc::clone(&self.tasks),
571        })
572    }
573}
574
575impl AsyncWrite for SubjectWriter {
576    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str(), buf = format!("{buf:02x?}")))]
577    fn poll_write(
578        mut self: Pin<&mut Self>,
579        cx: &mut Context<'_>,
580        mut buf: &[u8],
581    ) -> Poll<std::io::Result<usize>> {
582        trace!("polling for readiness");
583        match self.nats.poll_ready_unpin(cx) {
584            Poll::Pending => return Poll::Pending,
585            Poll::Ready(Err(err)) => {
586                return Poll::Ready(Err(std::io::Error::new(
587                    std::io::ErrorKind::BrokenPipe,
588                    err,
589                )))
590            }
591            Poll::Ready(Ok(())) => {}
592        }
593        let ServerInfo { max_payload, .. } = self.nats.server_info();
594        if max_payload == 0 {
595            return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
596        }
597        if buf.len() > max_payload {
598            (buf, _) = buf.split_at(max_payload);
599        }
600        trace!("starting send");
601        let subject = self.tx.clone();
602        match self.nats.start_send_unpin(OutboundMessage {
603            subject,
604            payload: Bytes::copy_from_slice(buf),
605            reply: None,
606            headers: None,
607        }) {
608            Ok(()) => Poll::Ready(Ok(buf.len())),
609            Err(err) => Poll::Ready(Err(std::io::Error::new(
610                std::io::ErrorKind::BrokenPipe,
611                err,
612            ))),
613        }
614    }
615
616    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
617    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
618        trace!("flushing");
619        self.nats
620            .poll_flush_unpin(cx)
621            .map_err(|_| std::io::ErrorKind::BrokenPipe.into())
622    }
623
624    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
625    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
626        trace!("writing stream shutdown message");
627        ready!(self.as_mut().poll_write(cx, &[]))?;
628        self.shutdown = true;
629        Poll::Ready(Ok(()))
630    }
631}
632
633impl Drop for SubjectWriter {
634    fn drop(&mut self) {
635        if !self.shutdown {
636            let nats = self.nats.clone();
637            let subject = mem::replace(&mut self.tx, Subject::from_static(""));
638            let tasks = Arc::clone(&self.tasks);
639            spawn_async(async move {
640                trace!("writing stream shutdown message");
641                if let Err(err) = nats.publish(subject, Bytes::default()).await {
642                    warn!(?err, "failed to publish stream shutdown message");
643                }
644                drop(tasks);
645            });
646        }
647    }
648}
649
650#[derive(Default)]
651pub enum RootParamWriter {
652    #[default]
653    Corrupted,
654    Handshaking {
655        nats: async_nats::Client,
656        sub: Subscriber,
657        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
658        buffer: Bytes,
659        tasks: Arc<JoinSet<()>>,
660    },
661    Draining {
662        tx: SubjectWriter,
663        buffer: Bytes,
664    },
665    Active(SubjectWriter),
666}
667
668impl RootParamWriter {
669    fn new(
670        nats: async_nats::Client,
671        sub: Subscriber,
672        buffer: Bytes,
673        tasks: Arc<JoinSet<()>>,
674    ) -> Self {
675        Self::Handshaking {
676            nats,
677            sub,
678            indexed: std::sync::Mutex::default(),
679            buffer,
680            tasks,
681        }
682    }
683}
684
685impl RootParamWriter {
686    #[instrument(level = "trace", skip_all, ret)]
687    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
688        match &mut *self {
689            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
690            Self::Handshaking { sub, .. } => {
691                trace!("polling for handshake response");
692                match sub.poll_next_unpin(cx) {
693                    Poll::Ready(Some(Message {
694                        status: Some(StatusCode::NO_RESPONDERS),
695                        ..
696                    })) => Poll::Ready(Err(std::io::ErrorKind::NotConnected.into())),
697                    Poll::Ready(Some(Message {
698                        status: Some(StatusCode::TIMEOUT),
699                        ..
700                    })) => Poll::Ready(Err(std::io::ErrorKind::TimedOut.into())),
701                    Poll::Ready(Some(Message {
702                        status: Some(StatusCode::REQUEST_TERMINATED),
703                        ..
704                    })) => Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
705                    Poll::Ready(Some(Message {
706                        status: Some(code),
707                        description,
708                        ..
709                    })) if !code.is_success() => Poll::Ready(Err(std::io::Error::other(
710                        if let Some(description) = description {
711                            format!("received a response with code `{code}` ({description})")
712                        } else {
713                            format!("received a response with code `{code}`")
714                        },
715                    ))),
716                    Poll::Ready(Some(Message {
717                        reply: Some(tx), ..
718                    })) => {
719                        let Self::Handshaking {
720                            nats,
721                            indexed,
722                            buffer,
723                            tasks,
724                            ..
725                        } = mem::take(&mut *self)
726                        else {
727                            return Poll::Ready(Err(corrupted_memory_error()));
728                        };
729                        let tx = SubjectWriter::new(nats, Subject::from(param_subject(&tx)), tasks);
730                        let indexed = indexed
731                            .into_inner()
732                            .map_err(|err| std::io::Error::other(err.to_string()))?;
733                        for (path, tx_tx) in indexed {
734                            let tx = tx.index(&path).map_err(std::io::Error::other)?;
735                            tx_tx.send(tx).map_err(|_| {
736                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
737                            })?;
738                        }
739                        trace!("handshake succeeded");
740                        if buffer.is_empty() {
741                            *self = Self::Active(tx);
742                            Poll::Ready(Ok(()))
743                        } else {
744                            *self = Self::Draining { tx, buffer };
745                            self.poll_active(cx)
746                        }
747                    }
748                    Poll::Ready(Some(..)) => Poll::Ready(Err(std::io::Error::new(
749                        std::io::ErrorKind::InvalidInput,
750                        "peer did not specify a reply subject",
751                    ))),
752                    Poll::Ready(None) => {
753                        *self = Self::Corrupted;
754                        Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
755                    }
756                    Poll::Pending => Poll::Pending,
757                }
758            }
759            Self::Draining { tx, buffer } => {
760                let mut tx = pin!(tx);
761                while !buffer.is_empty() {
762                    trace!(?tx.tx, "draining parameter buffer");
763                    match tx.as_mut().poll_write(cx, buffer) {
764                        Poll::Ready(Ok(n)) => {
765                            buffer.advance(n);
766                        }
767                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
768                        Poll::Pending => return Poll::Pending,
769                    }
770                }
771                let Self::Draining { tx, .. } = mem::take(&mut *self) else {
772                    return Poll::Ready(Err(corrupted_memory_error()));
773                };
774                trace!("parameter buffer draining succeeded");
775                *self = Self::Active(tx);
776                Poll::Ready(Ok(()))
777            }
778            Self::Active(..) => Poll::Ready(Ok(())),
779        }
780    }
781}
782
783impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
784    #[instrument(level = "trace", skip(self))]
785    fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
786        ensure!(!path.is_empty());
787        match self {
788            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
789            Self::Handshaking { indexed, .. } => {
790                let (tx_tx, tx_rx) = oneshot::channel();
791                let mut indexed = indexed
792                    .lock()
793                    .map_err(|err| std::io::Error::other(err.to_string()))?;
794                indexed.push((path.to_vec(), tx_tx));
795                Ok(IndexedParamWriter::Handshaking {
796                    tx_rx,
797                    indexed: std::sync::Mutex::default(),
798                })
799            }
800            Self::Draining { tx, .. } | Self::Active(tx) => {
801                tx.index(path).map(IndexedParamWriter::Active)
802            }
803        }
804    }
805}
806
807impl AsyncWrite for RootParamWriter {
808    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
809    fn poll_write(
810        mut self: Pin<&mut Self>,
811        cx: &mut Context<'_>,
812        buf: &[u8],
813    ) -> Poll<std::io::Result<usize>> {
814        match self.as_mut().poll_active(cx)? {
815            Poll::Ready(()) => {
816                let Self::Active(tx) = &mut *self else {
817                    return Poll::Ready(Err(corrupted_memory_error()));
818                };
819                trace!("writing buffer");
820                pin!(tx).poll_write(cx, buf)
821            }
822            Poll::Pending => Poll::Pending,
823        }
824    }
825
826    #[instrument(level = "trace", skip_all, ret)]
827    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
828        match self.as_mut().poll_active(cx)? {
829            Poll::Ready(()) => {
830                let Self::Active(tx) = &mut *self else {
831                    return Poll::Ready(Err(corrupted_memory_error()));
832                };
833                trace!("flushing");
834                pin!(tx).poll_flush(cx)
835            }
836            Poll::Pending => Poll::Pending,
837        }
838    }
839
840    #[instrument(level = "trace", skip_all, ret)]
841    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
842        match self.as_mut().poll_active(cx)? {
843            Poll::Ready(()) => {
844                let Self::Active(tx) = &mut *self else {
845                    return Poll::Ready(Err(corrupted_memory_error()));
846                };
847                trace!("shutting down");
848                pin!(tx).poll_shutdown(cx)
849            }
850            Poll::Pending => Poll::Pending,
851        }
852    }
853}
854
855#[derive(Debug, Default)]
856pub enum IndexedParamWriter {
857    #[default]
858    Corrupted,
859    Handshaking {
860        tx_rx: oneshot::Receiver<SubjectWriter>,
861        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
862    },
863    Active(SubjectWriter),
864}
865
866impl IndexedParamWriter {
867    #[instrument(level = "trace", skip_all, ret)]
868    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
869        match &mut *self {
870            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
871            Self::Handshaking { tx_rx, .. } => {
872                trace!("polling for handshake");
873                match pin!(tx_rx).poll(cx) {
874                    Poll::Ready(Ok(tx)) => {
875                        let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
876                            return Poll::Ready(Err(corrupted_memory_error()));
877                        };
878                        let indexed = indexed
879                            .into_inner()
880                            .map_err(|err| std::io::Error::other(err.to_string()))?;
881                        for (path, tx_tx) in indexed {
882                            let tx = tx.index(&path).map_err(std::io::Error::other)?;
883                            tx_tx.send(tx).map_err(|_| {
884                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
885                            })?;
886                        }
887                        *self = Self::Active(tx);
888                        Poll::Ready(Ok(()))
889                    }
890                    Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
891                    Poll::Pending => Poll::Pending,
892                }
893            }
894            Self::Active(..) => Poll::Ready(Ok(())),
895        }
896    }
897}
898
899impl wrpc_transport::Index<Self> for IndexedParamWriter {
900    #[instrument(level = "trace", skip_all)]
901    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
902        ensure!(!path.is_empty());
903        match self {
904            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
905            Self::Handshaking { indexed, .. } => {
906                let (tx_tx, tx_rx) = oneshot::channel();
907                let mut indexed = indexed
908                    .lock()
909                    .map_err(|err| std::io::Error::other(err.to_string()))?;
910                indexed.push((path.to_vec(), tx_tx));
911                Ok(Self::Handshaking {
912                    tx_rx,
913                    indexed: std::sync::Mutex::default(),
914                })
915            }
916            Self::Active(tx) => tx.index(path).map(Self::Active),
917        }
918    }
919}
920
921impl AsyncWrite for IndexedParamWriter {
922    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
923    fn poll_write(
924        mut self: Pin<&mut Self>,
925        cx: &mut Context<'_>,
926        buf: &[u8],
927    ) -> Poll<std::io::Result<usize>> {
928        match self.as_mut().poll_active(cx)? {
929            Poll::Ready(()) => {
930                let Self::Active(tx) = &mut *self else {
931                    return Poll::Ready(Err(corrupted_memory_error()));
932                };
933                trace!("writing buffer");
934                pin!(tx).poll_write(cx, buf)
935            }
936            Poll::Pending => Poll::Pending,
937        }
938    }
939
940    #[instrument(level = "trace", skip_all, ret)]
941    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
942        match self.as_mut().poll_active(cx)? {
943            Poll::Ready(()) => {
944                let Self::Active(tx) = &mut *self else {
945                    return Poll::Ready(Err(corrupted_memory_error()));
946                };
947                trace!("flushing");
948                pin!(tx).poll_flush(cx)
949            }
950            Poll::Pending => Poll::Pending,
951        }
952    }
953
954    #[instrument(level = "trace", skip_all, ret)]
955    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
956        match self.as_mut().poll_active(cx)? {
957            Poll::Ready(()) => {
958                let Self::Active(tx) = &mut *self else {
959                    return Poll::Ready(Err(corrupted_memory_error()));
960                };
961                trace!("shutting down");
962                pin!(tx).poll_shutdown(cx)
963            }
964            Poll::Pending => Poll::Pending,
965        }
966    }
967}
968
/// Outgoing parameter writer returned by `Client`'s `invoke`.
///
/// `invoke` hands out the `Root` variant; indexing either variant always yields a
/// `Nested` writer.
pub enum ParamWriter {
    /// Writer for the root (un-indexed) parameter stream.
    Root(RootParamWriter),
    /// Writer for a nested parameter path obtained via `index`.
    Nested(IndexedParamWriter),
}
973
974impl wrpc_transport::Index<Self> for ParamWriter {
975    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
976        ensure!(!path.is_empty());
977        match self {
978            ParamWriter::Root(w) => w.index(path),
979            ParamWriter::Nested(w) => w.index(path),
980        }
981        .map(Self::Nested)
982    }
983}
984
985impl AsyncWrite for ParamWriter {
986    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
987    fn poll_write(
988        mut self: Pin<&mut Self>,
989        cx: &mut Context<'_>,
990        buf: &[u8],
991    ) -> Poll<std::io::Result<usize>> {
992        match &mut *self {
993            ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
994            ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
995        }
996    }
997
998    #[instrument(level = "trace", skip_all, ret)]
999    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1000        match &mut *self {
1001            ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1002            ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1003        }
1004    }
1005
1006    #[instrument(level = "trace", skip_all, ret)]
1007    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1008        match &mut *self {
1009            ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1010            ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1011        }
1012    }
1013}
1014
1015impl wrpc_transport::Invoke for Client {
1016    type Context = Option<HeaderMap>;
1017    type Outgoing = ParamWriter;
1018    type Incoming = Reader;
1019
1020    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
1021    async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
1022        &self,
1023        cx: Self::Context,
1024        instance: &str,
1025        func: &str,
1026        mut params: Bytes,
1027        paths: impl AsRef<[P]> + Send,
1028    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
1029        let paths = paths.as_ref();
1030        let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));
1031
1032        let rx = Subject::from(new_inbox(&self.inbox));
1033        let (handshake_tx, handshake_rx) = mpsc::channel(1);
1034        cmds.push(Command::Subscribe(rx.clone(), handshake_tx));
1035
1036        let result = Subject::from(result_subject(&rx));
1037        let (result_tx, result_rx) = mpsc::channel(16);
1038        cmds.push(Command::Subscribe(result.clone(), result_tx));
1039
1040        let nested = paths.iter().map(|path| {
1041            let (tx, rx) = mpsc::channel(16);
1042            let subject = Subject::from(subscribe_path(&result, path.as_ref()));
1043            cmds.push(Command::Subscribe(subject.clone(), tx));
1044            Subscriber {
1045                rx: ReceiverStream::new(rx),
1046                commands: self.commands.clone(),
1047                subject,
1048                tasks: Arc::clone(&self.tasks),
1049            }
1050        });
1051        let nested: IndexTrie = zip(paths.iter(), nested).collect();
1052        ensure!(
1053            paths.is_empty() == nested.is_empty(),
1054            "failed to construct subscription tree"
1055        );
1056
1057        self.commands
1058            .send(Command::Batch(cmds.into_boxed_slice()))
1059            .await
1060            .context("failed to subscribe")?;
1061
1062        let ServerInfo {
1063            mut max_payload, ..
1064        } = self.nats.server_info();
1065        max_payload = max_payload.saturating_sub(rx.len());
1066        let param_tx = Subject::from(invocation_subject(&self.prefix, instance, func));
1067        if let Some(headers) = cx {
1068            // based on https://github.com/nats-io/nats.rs/blob/0942c473ce56163fdd1fbc62762f8164e3afa7bf/async-nats/src/header.rs#L215-L224
1069            max_payload = max_payload
1070                .saturating_sub(b"NATS/1.0\r\n".len())
1071                .saturating_sub(b"\r\n".len());
1072            for (k, vs) in headers.iter() {
1073                let k: &[u8] = k.as_ref();
1074                for v in vs {
1075                    max_payload = max_payload
1076                        .saturating_sub(k.len())
1077                        .saturating_sub(b": ".len())
1078                        .saturating_sub(v.as_str().len())
1079                        .saturating_sub(b"\r\n".len());
1080                }
1081            }
1082            trace!("publishing handshake");
1083            self.nats
1084                .publish_with_reply_and_headers(
1085                    param_tx,
1086                    rx.clone(),
1087                    headers,
1088                    params.split_to(max_payload.min(params.len())),
1089                )
1090                .await
1091        } else {
1092            trace!("publishing handshake");
1093            self.nats
1094                .publish_with_reply(
1095                    param_tx,
1096                    rx.clone(),
1097                    params.split_to(max_payload.min(params.len())),
1098                )
1099                .await
1100        }
1101        .context("failed to publish handshake")?;
1102        let nats = Arc::clone(&self.nats);
1103        tokio::spawn(async move {
1104            if let Err(err) = nats.flush().await {
1105                error!(?err, "failed to flush");
1106            }
1107        });
1108        Ok((
1109            ParamWriter::Root(RootParamWriter::new(
1110                (*self.nats).clone(),
1111                Subscriber {
1112                    rx: ReceiverStream::new(handshake_rx),
1113                    commands: self.commands.clone(),
1114                    subject: rx,
1115                    tasks: Arc::clone(&self.tasks),
1116                },
1117                params,
1118                Arc::clone(&self.tasks),
1119            )),
1120            Reader {
1121                buffer: Bytes::default(),
1122                incoming: Some(Subscriber {
1123                    rx: ReceiverStream::new(result_rx),
1124                    commands: self.commands.clone(),
1125                    subject: result,
1126                    tasks: Arc::clone(&self.tasks),
1127                }),
1128                nested: Arc::new(std::sync::Mutex::new(nested)),
1129                path: Box::default(),
1130            },
1131        ))
1132    }
1133}
1134
1135async fn handle_message(
1136    nats: &async_nats::Client,
1137    rx: Subject,
1138    commands: mpsc::Sender<Command>,
1139    async_nats::Message {
1140        reply: tx,
1141        payload,
1142        headers,
1143        ..
1144    }: async_nats::Message,
1145    paths: &[Box<[Option<usize>]>],
1146    tasks: Arc<JoinSet<()>>,
1147) -> anyhow::Result<(Option<HeaderMap>, SubjectWriter, Reader)> {
1148    let tx = tx.context("peer did not specify a reply subject")?;
1149
1150    let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1151
1152    let param = Subject::from(param_subject(&rx));
1153    let (param_tx, param_rx) = mpsc::channel(16);
1154    cmds.push(Command::Subscribe(param.clone(), param_tx));
1155
1156    let nested = paths.iter().map(|path| {
1157        let (tx, rx) = mpsc::channel(16);
1158        let subject = Subject::from(subscribe_path(&param, path.as_ref()));
1159        cmds.push(Command::Subscribe(subject.clone(), tx));
1160        Subscriber {
1161            rx: ReceiverStream::new(rx),
1162            commands: commands.clone(),
1163            subject,
1164            tasks: Arc::clone(&tasks),
1165        }
1166    });
1167    let nested: IndexTrie = zip(paths.iter(), nested).collect();
1168    ensure!(
1169        paths.is_empty() == nested.is_empty(),
1170        "failed to construct subscription tree"
1171    );
1172
1173    commands
1174        .send(Command::Batch(cmds.into_boxed_slice()))
1175        .await
1176        .context("failed to subscribe")?;
1177
1178    trace!("publishing handshake response");
1179    nats.publish_with_reply(tx.clone(), rx, Bytes::default())
1180        .await
1181        .context("failed to publish handshake accept")?;
1182    Ok((
1183        headers,
1184        SubjectWriter::new(
1185            nats.clone(),
1186            Subject::from(result_subject(&tx)),
1187            Arc::clone(&tasks),
1188        ),
1189        Reader {
1190            buffer: payload,
1191            incoming: Some(Subscriber {
1192                rx: ReceiverStream::new(param_rx),
1193                commands,
1194                subject: param,
1195                tasks,
1196            }),
1197            nested: Arc::new(std::sync::Mutex::new(nested)),
1198            path: Box::default(),
1199        },
1200    ))
1201}
1202
1203impl wrpc_transport::Serve for Client {
1204    type Context = Option<HeaderMap>;
1205    type Outgoing = SubjectWriter;
1206    type Incoming = Reader;
1207
1208    #[instrument(level = "trace", skip(self, paths))]
1209    async fn serve(
1210        &self,
1211        instance: &str,
1212        func: &str,
1213        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1214    ) -> anyhow::Result<
1215        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1216    > {
1217        let subject = invocation_subject(&self.prefix, instance, func);
1218        let sub = if let Some(group) = &self.queue_group {
1219            debug!(subject, ?group, "queue-subscribing on invocation subject");
1220            self.nats
1221                .queue_subscribe(subject, group.to_string())
1222                .await?
1223        } else {
1224            debug!(subject, "subscribing on invocation subject");
1225            self.nats.subscribe(subject).await?
1226        };
1227        let nats = Arc::clone(&self.nats);
1228        let paths = paths.into();
1229        let commands = self.commands.clone();
1230        let inbox = Arc::clone(&self.inbox);
1231        let tasks = Arc::clone(&self.tasks);
1232        Ok(sub.then(move |msg| {
1233            let tasks = Arc::clone(&tasks);
1234            let nats = Arc::clone(&nats);
1235            let paths = Arc::clone(&paths);
1236            let commands = commands.clone();
1237            let rx = Subject::from(new_inbox(&inbox));
1238            async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1239        }))
1240    }
1241}