1#![allow(clippy::type_complexity)]
4
5use core::future::Future;
6use core::iter::zip;
7use core::ops::{Deref, DerefMut};
8use core::pin::{pin, Pin};
9use core::task::{ready, Context, Poll};
10use core::{mem, str};
11
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use anyhow::{anyhow, ensure, Context as _};
16use async_nats::message::OutboundMessage;
17use async_nats::{HeaderMap, ServerInfo, StatusCode, Subject};
18use bytes::{Buf as _, Bytes};
19use futures::sink::SinkExt as _;
20use futures::{Stream, StreamExt};
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tokio::select;
23use tokio::sync::{mpsc, oneshot};
24use tokio::task::JoinSet;
25use tokio_stream::wrappers::ReceiverStream;
26use tracing::{debug, error, instrument, trace, warn};
27use wrpc_transport::Index as _;
28
/// Protocol version identifier embedded in every invocation subject
/// (see [`invocation_subject`]).
pub const PROTOCOL: &str = "wrpc.0.0.1";
30
31fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
32 match tokio::runtime::Handle::try_current() {
33 Ok(rt) => {
34 rt.spawn(fut);
35 }
36 Err(_) => match tokio::runtime::Runtime::new() {
37 Ok(rt) => {
38 rt.spawn(fut);
39 }
40 Err(err) => error!(?err, "failed to create a new Tokio runtime"),
41 },
42 }
43}
44
45fn new_inbox(inbox: &str) -> String {
46 let id = nuid::next();
47 let mut s = String::with_capacity(inbox.len().saturating_add(id.len()));
48 s.push_str(inbox);
49 s.push_str(&id);
50 s
51}
52
/// Returns the subject on which the parameter stream rooted at `prefix` is sent.
#[must_use]
#[inline]
pub fn param_subject(prefix: &str) -> String {
    let mut subject = String::from(prefix);
    subject.push_str(".params");
    subject
}
58
/// Returns the subject on which the result stream rooted at `prefix` is sent.
#[must_use]
#[inline]
pub fn result_subject(prefix: &str) -> String {
    let mut subject = String::from(prefix);
    subject.push_str(".results");
    subject
}
64
/// Appends each index in `path` to `prefix`, using `.` as the separator.
///
/// No leading separator is emitted when `prefix` is empty.
#[must_use]
#[inline]
pub fn index_path(prefix: &str, path: &[usize]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    path.iter().fold(subject, |mut acc, idx| {
        if !acc.is_empty() {
            acc.push('.');
        }
        acc.push_str(&idx.to_string());
        acc
    })
}
80
/// Builds a subscription subject from `prefix` and `path`, with `.` as the separator.
///
/// `Some(i)` elements become the index `i`; `None` elements become the NATS
/// single-token wildcard `*`. No leading separator is emitted when `prefix` is empty.
#[must_use]
#[inline]
pub fn subscribe_path(prefix: &str, path: &[Option<usize>]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    path.iter().fold(subject, |mut acc, segment| {
        if !acc.is_empty() {
            acc.push('.');
        }
        match segment {
            Some(idx) => acc.push_str(&idx.to_string()),
            None => acc.push('*'),
        }
        acc
    })
}
100
101#[must_use]
102#[inline]
103pub fn invocation_subject(prefix: &str, instance: &str, func: &str) -> String {
104 let mut s =
105 String::with_capacity(prefix.len() + PROTOCOL.len() + instance.len() + func.len() + 3);
106 if !prefix.is_empty() {
107 s.push_str(prefix);
108 s.push('.');
109 }
110 s.push_str(PROTOCOL);
111 s.push('.');
112 if !instance.is_empty() {
113 s.push_str(instance);
114 s.push('.');
115 }
116 s.push_str(func);
117 s
118}
119
/// Error returned when a state machine is found in its placeholder (`Corrupted`) state.
fn corrupted_memory_error() -> std::io::Error {
    let message = "corrupted memory state";
    std::io::Error::other(message)
}
123
/// A subscription on a single subject, fed by the [`Client`]'s message dispatcher task.
///
/// Dropping it sends a [`Command::Unsubscribe`] for `subject` to the dispatcher.
pub struct Subscriber {
    /// Messages routed to `subject` by the dispatcher.
    rx: ReceiverStream<Message>,
    /// Subject this subscriber is registered under.
    subject: Subject,
    /// Command channel of the dispatcher task, used to unsubscribe on drop.
    commands: mpsc::Sender<Command>,
    /// Shared handle to the task set the dispatcher was spawned on.
    tasks: Arc<JoinSet<()>>,
}
131
132impl Drop for Subscriber {
133 fn drop(&mut self) {
134 let commands = self.commands.clone();
135 let subject = mem::replace(&mut self.subject, Subject::from_static(""));
136 let tasks = Arc::clone(&self.tasks);
137 spawn_async(async move {
138 trace!(?subject, "shutting down subscriber");
139 if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
140 warn!(?err, "failed to shutdown subscriber");
141 }
142 drop(tasks);
143 });
144 }
145}
146
147impl Deref for Subscriber {
148 type Target = ReceiverStream<Message>;
149
150 fn deref(&self) -> &Self::Target {
151 &self.rx
152 }
153}
154
155impl DerefMut for Subscriber {
156 fn deref_mut(&mut self) -> &mut Self::Target {
157 &mut self.rx
158 }
159}
160
/// Requests handled by the [`Client`]'s message dispatcher task.
enum Command {
    /// Route messages received on the subject to the given sender.
    Subscribe(Subject, mpsc::Sender<Message>),
    /// Stop routing messages received on the subject.
    Unsubscribe(Subject),
    /// Apply several commands, in order, delivered as a single channel message.
    Batch(Box<[Command]>),
}
166
/// The parts of an [`async_nats::Message`] forwarded to a [`Subscriber`].
pub struct Message {
    /// Reply subject of the message, if any.
    reply: Option<Subject>,
    /// Message payload; readers in this module treat an empty payload as end-of-stream.
    payload: Bytes,
    /// Status code of the message, if any (e.g. `NO_RESPONDERS` for an unanswered request).
    status: Option<async_nats::StatusCode>,
    /// Optional description accompanying `status`.
    description: Option<String>,
}
174
/// wRPC transport over NATS, usable both to invoke and to serve functions.
///
/// Clones share the NATS connection and the message dispatcher task.
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying NATS connection.
    nats: Arc<async_nats::Client>,
    /// Subject prefix prepended to invocation subjects.
    prefix: Arc<str>,
    /// Unique inbox prefix (ending with `.`) under which all reply subjects of this client live.
    inbox: Arc<str>,
    /// Queue group used when serving, if any.
    queue_group: Option<Arc<str>>,
    /// Command channel of the dispatcher task.
    commands: mpsc::Sender<Command>,
    /// Task set owning the dispatcher task spawned in `Client::new`.
    tasks: Arc<JoinSet<()>>,
}
184
185impl Client {
186 pub async fn new(
187 nats: impl Into<Arc<async_nats::Client>>,
188 prefix: impl Into<Arc<str>>,
189 queue_group: Option<Arc<str>>,
190 ) -> anyhow::Result<Self> {
191 let nats = nats.into();
192 let mut inbox = nats.new_inbox();
193 inbox.push('.');
194 let mut subject = String::with_capacity(inbox.len().saturating_add(1));
195 subject.push_str(&inbox);
196 subject.push('>');
197 let mut sub = nats
198 .subscribe(Subject::from(subject))
199 .await
200 .context("failed to subscribe on an inbox subject")?;
201
202 let mut tasks = JoinSet::new();
203 let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
204 tasks.spawn({
205 async move {
206 fn handle_command(subs: &mut HashMap<String, mpsc::Sender<Message>>, cmd: Command) {
207 match cmd {
208 Command::Subscribe(s, tx) => {
209 subs.insert(s.into_string(), tx);
210 }
211 Command::Unsubscribe(s) => {
212 subs.remove(s.as_str());
213 }
214 Command::Batch(cmds) => {
215 for cmd in cmds {
216 handle_command(subs, cmd);
217 }
218 }
219 }
220 }
221 async fn handle_message(
222 subs: &mut HashMap<String, mpsc::Sender<Message>>,
223 async_nats::Message {
224 subject,
225 reply,
226 payload,
227 status,
228 description,
229 ..
230 }: async_nats::Message,
231 ) {
232 let Some(sub) = subs.get_mut(subject.as_str()) else {
233 debug!(?subject, "drop message with no subscriber");
234 return;
235 };
236 let Ok(sub) = sub.reserve().await else {
237 debug!(?subject, "drop message with closed subscriber");
238 subs.remove(subject.as_str());
239 return;
240 };
241 sub.send(Message {
242 reply,
243 payload,
244 status,
245 description,
246 });
247 }
248
249 let mut subs = HashMap::new();
250 loop {
251 select! {
252 Some(msg) = sub.next() => handle_message(&mut subs, msg).await,
253 Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
254 else => return,
255 }
256 }
257 }
258 });
259 Ok(Self {
260 nats,
261 prefix: prefix.into(),
262 inbox: inbox.into(),
263 queue_group,
264 commands: cmd_tx,
265 tasks: Arc::new(tasks),
266 })
267 }
268}
269
/// A [`Stream`] of the raw payloads of the messages received by a [`Subscriber`].
pub struct ByteSubscription(Subscriber);
271
272impl Stream for ByteSubscription {
273 type Item = std::io::Result<Bytes>;
274
275 #[instrument(level = "trace", skip_all)]
276 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277 match self.0.poll_next_unpin(cx) {
278 Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
279 Poll::Ready(None) => Poll::Ready(None),
280 Poll::Pending => Poll::Pending,
281 }
282 }
283}
284
/// Trie of subscriptions for nested values, keyed by index path.
///
/// `Some(i)` path elements are stored in `IndexNode`s (child `i` lives in `nested[i]`);
/// `None` path elements (wildcards) are stored in `WildcardNode`s.
#[derive(Default)]
enum IndexTrie {
    /// No subscriptions.
    #[default]
    Empty,
    /// Subscription for exactly this path, with nothing nested below it.
    Leaf(Subscriber),
    /// Node whose children are addressed by index.
    IndexNode {
        /// Subscription for the path leading to this node, if any.
        subscriber: Option<Subscriber>,
        /// Child subtrees; `nested[i]` is the subtree for index `i`.
        nested: Vec<Option<IndexTrie>>,
    },
    /// Node whose single child is reached through a wildcard path element.
    WildcardNode {
        /// Subscription for the path leading to this node, if any.
        subscriber: Option<Subscriber>,
        /// Subtree below the wildcard element.
        nested: Option<Box<IndexTrie>>,
    },
}
299
300impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
301 fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
302 match path {
303 [] => Self::Leaf(sub),
304 [None, path @ ..] => Self::WildcardNode {
305 subscriber: None,
306 nested: Some(Box::new(Self::from((path, sub)))),
307 },
308 [Some(i), path @ ..] => Self::IndexNode {
309 subscriber: None,
310 nested: {
311 let n = i.saturating_add(1);
312 let mut nested = Vec::with_capacity(n);
313 nested.resize_with(n, Option::default);
314 nested[*i] = Some(Self::from((path, sub)));
315 nested
316 },
317 },
318 }
319 }
320}
321
322impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
323 fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
324 let mut root = Self::Empty;
325 for (path, sub) in iter {
326 if !root.insert(path.as_ref(), sub) {
327 return Self::Empty;
328 }
329 }
330 root
331 }
332}
333
334impl IndexTrie {
335 #[inline]
336 fn is_empty(&self) -> bool {
337 matches!(self, IndexTrie::Empty)
338 }
339
340 #[instrument(level = "trace", skip_all)]
341 fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
342 let Some((i, path)) = path.split_first() else {
343 return match mem::take(self) {
344 IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
355 IndexTrie::Leaf(subscriber) => Some(subscriber),
356 IndexTrie::IndexNode { subscriber, nested } => {
357 if !nested.is_empty() {
358 *self = IndexTrie::IndexNode {
359 subscriber: None,
360 nested,
361 }
362 }
363 subscriber
364 }
365 };
366 };
367 match self {
368 Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
373 Self::IndexNode { ref mut nested, .. } => nested
374 .get_mut(*i)
375 .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
376 }
377 }
378
379 #[instrument(level = "trace", skip_all)]
382 fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
383 match self {
384 Self::Empty => {
385 *self = Self::from((path, sub));
386 true
387 }
388 Self::Leaf(..) => {
389 let Some((i, path)) = path.split_first() else {
390 return false;
391 };
392 let Self::Leaf(subscriber) = mem::take(self) else {
393 return false;
394 };
395 if let Some(i) = i {
396 let n = i.saturating_add(1);
397 let mut nested = Vec::with_capacity(n);
398 nested.resize_with(n, Option::default);
399 nested[*i] = Some(Self::from((path, sub)));
400 *self = Self::IndexNode {
401 subscriber: Some(subscriber),
402 nested,
403 };
404 } else {
405 *self = Self::WildcardNode {
406 subscriber: Some(subscriber),
407 nested: Some(Box::new(Self::from((path, sub)))),
408 };
409 }
410 true
411 }
412 Self::WildcardNode {
413 ref mut subscriber,
414 ref mut nested,
415 } => match (&subscriber, path) {
416 (None, []) => {
417 *subscriber = Some(sub);
418 true
419 }
420 (_, [None, path @ ..]) => {
421 if let Some(nested) = nested {
422 nested.insert(path, sub)
423 } else {
424 *nested = Some(Box::new(Self::from((path, sub))));
425 true
426 }
427 }
428 _ => false,
429 },
430 Self::IndexNode {
431 ref mut subscriber,
432 ref mut nested,
433 } => match (&subscriber, path) {
434 (None, []) => {
435 *subscriber = Some(sub);
436 true
437 }
438 (_, [Some(i), path @ ..]) => {
439 let cap = i.saturating_add(1);
440 if nested.len() < cap {
441 nested.resize_with(cap, Option::default);
442 }
443 let nested = &mut nested[*i];
444 if let Some(nested) = nested {
445 nested.insert(path, sub)
446 } else {
447 *nested = Some(Self::from((path, sub)));
448 true
449 }
450 }
451 _ => false,
452 },
453 }
454 }
455}
456
/// Reads a byte stream delivered as NATS messages on a subscription.
///
/// An empty message, or the end of the subscription, reads as end-of-stream. Readers for
/// nested values are obtained through [`wrpc_transport::Index`].
pub struct Reader {
    /// Unread remainder of the last received message.
    buffer: Bytes,
    /// Subscription delivering this reader's data; `None` if none was registered for `path`.
    incoming: Option<Subscriber>,
    /// Shared trie of nested subscriptions not yet claimed by an indexed reader.
    nested: Arc<std::sync::Mutex<IndexTrie>>,
    /// Index path of this reader, relative to the root reader.
    path: Box<[usize]>,
}
463
464impl wrpc_transport::Index<Self> for Reader {
465 #[instrument(level = "trace", skip(self))]
466 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
467 ensure!(!path.is_empty());
468 trace!("locking index tree");
469 let mut nested = self
470 .nested
471 .lock()
472 .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
473 trace!("taking index subscription");
474 let mut p = self.path.to_vec();
475 p.extend_from_slice(path);
476 let incoming = nested.take(&p);
477 Ok(Self {
478 buffer: Bytes::default(),
479 incoming,
480 nested: Arc::clone(&self.nested),
481 path: p.into_boxed_slice(),
482 })
483 }
484}
485
486impl AsyncRead for Reader {
487 #[instrument(level = "trace", skip_all, ret)]
488 fn poll_read(
489 mut self: Pin<&mut Self>,
490 cx: &mut Context<'_>,
491 buf: &mut ReadBuf<'_>,
492 ) -> Poll<std::io::Result<()>> {
493 let cap = buf.remaining();
494 if cap == 0 {
495 trace!("attempt to read empty buffer");
496 return Poll::Ready(Ok(()));
497 }
498
499 if !self.buffer.is_empty() {
500 if self.buffer.len() > cap {
501 trace!(cap, len = self.buffer.len(), "reading part of buffer");
502 buf.put_slice(&self.buffer.split_to(cap));
503 } else {
504 trace!(cap, len = self.buffer.len(), "reading full buffer");
505 buf.put_slice(&mem::take(&mut self.buffer));
506 }
507 return Poll::Ready(Ok(()));
508 }
509 let Some(incoming) = self.incoming.as_mut() else {
510 return Poll::Ready(Err(std::io::Error::new(
511 std::io::ErrorKind::NotFound,
512 format!("subscription not found for path {:?}", self.path),
513 )));
514 };
515 trace!("polling for next message");
516 match incoming.poll_next_unpin(cx) {
517 Poll::Ready(Some(Message { mut payload, .. })) => {
518 trace!(?payload, "received message");
519 if payload.is_empty() {
520 trace!("received stream shutdown message");
521 return Poll::Ready(Ok(()));
522 }
523 if payload.len() > cap {
524 trace!(len = payload.len(), cap, "partially reading the message");
525 buf.put_slice(&payload.split_to(cap));
526 self.buffer = payload;
527 } else {
528 trace!(len = payload.len(), cap, "filling the buffer with payload");
529 buf.put_slice(&payload);
530 }
531 Poll::Ready(Ok(()))
532 }
533 Poll::Ready(None) => {
534 trace!("subscription finished");
535 Poll::Ready(Ok(()))
536 }
537 Poll::Pending => Poll::Pending,
538 }
539 }
540}
541
/// Writes bytes as NATS messages published on a single subject.
///
/// An empty message signals end-of-stream. It is published on shutdown, or when the
/// writer is dropped without having been shut down.
#[derive(Clone, Debug)]
pub struct SubjectWriter {
    /// NATS connection used to publish.
    nats: async_nats::Client,
    /// Subject the data is published on.
    tx: Subject,
    /// Whether the end-of-stream message has already been written by `poll_shutdown`.
    shutdown: bool,
    /// Shared handle to the client's task set.
    tasks: Arc<JoinSet<()>>,
}
549
550impl SubjectWriter {
551 fn new(nats: async_nats::Client, tx: Subject, tasks: Arc<JoinSet<()>>) -> Self {
552 Self {
553 nats,
554 tx,
555 shutdown: false,
556 tasks,
557 }
558 }
559}
560
561impl wrpc_transport::Index<Self> for SubjectWriter {
562 #[instrument(level = "trace", skip(self))]
563 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
564 ensure!(!path.is_empty());
565 let tx = Subject::from(index_path(self.tx.as_str(), path));
566 Ok(Self {
567 nats: self.nats.clone(),
568 tx,
569 shutdown: false,
570 tasks: Arc::clone(&self.tasks),
571 })
572 }
573}
574
575impl AsyncWrite for SubjectWriter {
576 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str(), buf = format!("{buf:02x?}")))]
577 fn poll_write(
578 mut self: Pin<&mut Self>,
579 cx: &mut Context<'_>,
580 mut buf: &[u8],
581 ) -> Poll<std::io::Result<usize>> {
582 trace!("polling for readiness");
583 match self.nats.poll_ready_unpin(cx) {
584 Poll::Pending => return Poll::Pending,
585 Poll::Ready(Err(err)) => {
586 return Poll::Ready(Err(std::io::Error::new(
587 std::io::ErrorKind::BrokenPipe,
588 err,
589 )))
590 }
591 Poll::Ready(Ok(())) => {}
592 }
593 let ServerInfo { max_payload, .. } = self.nats.server_info();
594 if max_payload == 0 {
595 return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
596 }
597 if buf.len() > max_payload {
598 (buf, _) = buf.split_at(max_payload);
599 }
600 trace!("starting send");
601 let subject = self.tx.clone();
602 match self.nats.start_send_unpin(OutboundMessage {
603 subject,
604 payload: Bytes::copy_from_slice(buf),
605 reply: None,
606 headers: None,
607 }) {
608 Ok(()) => Poll::Ready(Ok(buf.len())),
609 Err(err) => Poll::Ready(Err(std::io::Error::new(
610 std::io::ErrorKind::BrokenPipe,
611 err,
612 ))),
613 }
614 }
615
616 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
617 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
618 trace!("flushing");
619 self.nats
620 .poll_flush_unpin(cx)
621 .map_err(|_| std::io::ErrorKind::BrokenPipe.into())
622 }
623
624 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
625 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
626 trace!("writing stream shutdown message");
627 ready!(self.as_mut().poll_write(cx, &[]))?;
628 self.shutdown = true;
629 Poll::Ready(Ok(()))
630 }
631}
632
633impl Drop for SubjectWriter {
634 fn drop(&mut self) {
635 if !self.shutdown {
636 let nats = self.nats.clone();
637 let subject = mem::replace(&mut self.tx, Subject::from_static(""));
638 let tasks = Arc::clone(&self.tasks);
639 spawn_async(async move {
640 trace!("writing stream shutdown message");
641 if let Err(err) = nats.publish(subject, Bytes::default()).await {
642 warn!(?err, "failed to publish stream shutdown message");
643 }
644 drop(tasks);
645 });
646 }
647 }
648}
649
/// Writer for the root parameter stream of an outgoing invocation.
///
/// Parameter bytes that did not fit in the invocation message itself are written once the
/// peer has answered the handshake.
#[derive(Default)]
pub enum RootParamWriter {
    /// Placeholder left behind while moving between states; any use yields an error.
    #[default]
    Corrupted,
    /// Waiting for the peer's handshake response on `sub`.
    Handshaking {
        /// NATS connection used to create the parameter writer.
        nats: async_nats::Client,
        /// Subscription on which the handshake response is expected.
        sub: Subscriber,
        /// Nested paths indexed before the handshake completed, with the channels used to
        /// deliver their writers.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
        /// Parameter bytes still to be sent after the handshake.
        buffer: Bytes,
        /// Shared handle to the client's task set.
        tasks: Arc<JoinSet<()>>,
    },
    /// Handshake completed; writing the remaining `buffer` to `tx`.
    Draining {
        tx: SubjectWriter,
        buffer: Bytes,
    },
    /// Ready; writes go straight to the parameter subject.
    Active(SubjectWriter),
}
667
668impl RootParamWriter {
669 fn new(
670 nats: async_nats::Client,
671 sub: Subscriber,
672 buffer: Bytes,
673 tasks: Arc<JoinSet<()>>,
674 ) -> Self {
675 Self::Handshaking {
676 nats,
677 sub,
678 indexed: std::sync::Mutex::default(),
679 buffer,
680 tasks,
681 }
682 }
683}
684
685impl RootParamWriter {
686 #[instrument(level = "trace", skip_all, ret)]
687 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
688 match &mut *self {
689 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
690 Self::Handshaking { sub, .. } => {
691 trace!("polling for handshake response");
692 match sub.poll_next_unpin(cx) {
693 Poll::Ready(Some(Message {
694 status: Some(StatusCode::NO_RESPONDERS),
695 ..
696 })) => Poll::Ready(Err(std::io::ErrorKind::NotConnected.into())),
697 Poll::Ready(Some(Message {
698 status: Some(StatusCode::TIMEOUT),
699 ..
700 })) => Poll::Ready(Err(std::io::ErrorKind::TimedOut.into())),
701 Poll::Ready(Some(Message {
702 status: Some(StatusCode::REQUEST_TERMINATED),
703 ..
704 })) => Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
705 Poll::Ready(Some(Message {
706 status: Some(code),
707 description,
708 ..
709 })) if !code.is_success() => Poll::Ready(Err(std::io::Error::other(
710 if let Some(description) = description {
711 format!("received a response with code `{code}` ({description})")
712 } else {
713 format!("received a response with code `{code}`")
714 },
715 ))),
716 Poll::Ready(Some(Message {
717 reply: Some(tx), ..
718 })) => {
719 let Self::Handshaking {
720 nats,
721 indexed,
722 buffer,
723 tasks,
724 ..
725 } = mem::take(&mut *self)
726 else {
727 return Poll::Ready(Err(corrupted_memory_error()));
728 };
729 let tx = SubjectWriter::new(nats, Subject::from(param_subject(&tx)), tasks);
730 let indexed = indexed
731 .into_inner()
732 .map_err(|err| std::io::Error::other(err.to_string()))?;
733 for (path, tx_tx) in indexed {
734 let tx = tx.index(&path).map_err(std::io::Error::other)?;
735 tx_tx.send(tx).map_err(|_| {
736 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
737 })?;
738 }
739 trace!("handshake succeeded");
740 if buffer.is_empty() {
741 *self = Self::Active(tx);
742 Poll::Ready(Ok(()))
743 } else {
744 *self = Self::Draining { tx, buffer };
745 self.poll_active(cx)
746 }
747 }
748 Poll::Ready(Some(..)) => Poll::Ready(Err(std::io::Error::new(
749 std::io::ErrorKind::InvalidInput,
750 "peer did not specify a reply subject",
751 ))),
752 Poll::Ready(None) => {
753 *self = Self::Corrupted;
754 Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
755 }
756 Poll::Pending => Poll::Pending,
757 }
758 }
759 Self::Draining { tx, buffer } => {
760 let mut tx = pin!(tx);
761 while !buffer.is_empty() {
762 trace!(?tx.tx, "draining parameter buffer");
763 match tx.as_mut().poll_write(cx, buffer) {
764 Poll::Ready(Ok(n)) => {
765 buffer.advance(n);
766 }
767 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
768 Poll::Pending => return Poll::Pending,
769 }
770 }
771 let Self::Draining { tx, .. } = mem::take(&mut *self) else {
772 return Poll::Ready(Err(corrupted_memory_error()));
773 };
774 trace!("parameter buffer draining succeeded");
775 *self = Self::Active(tx);
776 Poll::Ready(Ok(()))
777 }
778 Self::Active(..) => Poll::Ready(Ok(())),
779 }
780 }
781}
782
783impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
784 #[instrument(level = "trace", skip(self))]
785 fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
786 ensure!(!path.is_empty());
787 match self {
788 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
789 Self::Handshaking { indexed, .. } => {
790 let (tx_tx, tx_rx) = oneshot::channel();
791 let mut indexed = indexed
792 .lock()
793 .map_err(|err| std::io::Error::other(err.to_string()))?;
794 indexed.push((path.to_vec(), tx_tx));
795 Ok(IndexedParamWriter::Handshaking {
796 tx_rx,
797 indexed: std::sync::Mutex::default(),
798 })
799 }
800 Self::Draining { tx, .. } | Self::Active(tx) => {
801 tx.index(path).map(IndexedParamWriter::Active)
802 }
803 }
804 }
805}
806
807impl AsyncWrite for RootParamWriter {
808 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
809 fn poll_write(
810 mut self: Pin<&mut Self>,
811 cx: &mut Context<'_>,
812 buf: &[u8],
813 ) -> Poll<std::io::Result<usize>> {
814 match self.as_mut().poll_active(cx)? {
815 Poll::Ready(()) => {
816 let Self::Active(tx) = &mut *self else {
817 return Poll::Ready(Err(corrupted_memory_error()));
818 };
819 trace!("writing buffer");
820 pin!(tx).poll_write(cx, buf)
821 }
822 Poll::Pending => Poll::Pending,
823 }
824 }
825
826 #[instrument(level = "trace", skip_all, ret)]
827 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
828 match self.as_mut().poll_active(cx)? {
829 Poll::Ready(()) => {
830 let Self::Active(tx) = &mut *self else {
831 return Poll::Ready(Err(corrupted_memory_error()));
832 };
833 trace!("flushing");
834 pin!(tx).poll_flush(cx)
835 }
836 Poll::Pending => Poll::Pending,
837 }
838 }
839
840 #[instrument(level = "trace", skip_all, ret)]
841 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
842 match self.as_mut().poll_active(cx)? {
843 Poll::Ready(()) => {
844 let Self::Active(tx) = &mut *self else {
845 return Poll::Ready(Err(corrupted_memory_error()));
846 };
847 trace!("shutting down");
848 pin!(tx).poll_shutdown(cx)
849 }
850 Poll::Pending => Poll::Pending,
851 }
852 }
853}
854
/// Writer for a nested parameter value, obtained by indexing a parameter writer.
#[derive(Debug, Default)]
pub enum IndexedParamWriter {
    /// Placeholder left behind while moving between states; any use yields an error.
    #[default]
    Corrupted,
    /// Waiting for the parent writer to become active.
    Handshaking {
        /// Receives this writer's `SubjectWriter` once the parent is active.
        tx_rx: oneshot::Receiver<SubjectWriter>,
        /// Writers indexed from this one in the meantime, with their delivery channels.
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
    },
    /// Ready; writes go straight to the nested parameter subject.
    Active(SubjectWriter),
}
865
866impl IndexedParamWriter {
867 #[instrument(level = "trace", skip_all, ret)]
868 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
869 match &mut *self {
870 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
871 Self::Handshaking { tx_rx, .. } => {
872 trace!("polling for handshake");
873 match pin!(tx_rx).poll(cx) {
874 Poll::Ready(Ok(tx)) => {
875 let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
876 return Poll::Ready(Err(corrupted_memory_error()));
877 };
878 let indexed = indexed
879 .into_inner()
880 .map_err(|err| std::io::Error::other(err.to_string()))?;
881 for (path, tx_tx) in indexed {
882 let tx = tx.index(&path).map_err(std::io::Error::other)?;
883 tx_tx.send(tx).map_err(|_| {
884 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
885 })?;
886 }
887 *self = Self::Active(tx);
888 Poll::Ready(Ok(()))
889 }
890 Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
891 Poll::Pending => Poll::Pending,
892 }
893 }
894 Self::Active(..) => Poll::Ready(Ok(())),
895 }
896 }
897}
898
899impl wrpc_transport::Index<Self> for IndexedParamWriter {
900 #[instrument(level = "trace", skip_all)]
901 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
902 ensure!(!path.is_empty());
903 match self {
904 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
905 Self::Handshaking { indexed, .. } => {
906 let (tx_tx, tx_rx) = oneshot::channel();
907 let mut indexed = indexed
908 .lock()
909 .map_err(|err| std::io::Error::other(err.to_string()))?;
910 indexed.push((path.to_vec(), tx_tx));
911 Ok(Self::Handshaking {
912 tx_rx,
913 indexed: std::sync::Mutex::default(),
914 })
915 }
916 Self::Active(tx) => tx.index(path).map(Self::Active),
917 }
918 }
919}
920
921impl AsyncWrite for IndexedParamWriter {
922 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
923 fn poll_write(
924 mut self: Pin<&mut Self>,
925 cx: &mut Context<'_>,
926 buf: &[u8],
927 ) -> Poll<std::io::Result<usize>> {
928 match self.as_mut().poll_active(cx)? {
929 Poll::Ready(()) => {
930 let Self::Active(tx) = &mut *self else {
931 return Poll::Ready(Err(corrupted_memory_error()));
932 };
933 trace!("writing buffer");
934 pin!(tx).poll_write(cx, buf)
935 }
936 Poll::Pending => Poll::Pending,
937 }
938 }
939
940 #[instrument(level = "trace", skip_all, ret)]
941 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
942 match self.as_mut().poll_active(cx)? {
943 Poll::Ready(()) => {
944 let Self::Active(tx) = &mut *self else {
945 return Poll::Ready(Err(corrupted_memory_error()));
946 };
947 trace!("flushing");
948 pin!(tx).poll_flush(cx)
949 }
950 Poll::Pending => Poll::Pending,
951 }
952 }
953
954 #[instrument(level = "trace", skip_all, ret)]
955 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
956 match self.as_mut().poll_active(cx)? {
957 Poll::Ready(()) => {
958 let Self::Active(tx) = &mut *self else {
959 return Poll::Ready(Err(corrupted_memory_error()));
960 };
961 trace!("shutting down");
962 pin!(tx).poll_shutdown(cx)
963 }
964 Poll::Pending => Poll::Pending,
965 }
966 }
967}
968
/// Parameter writer returned by [`Client`]'s `invoke`: either the root parameter stream or
/// a writer for a nested parameter value.
pub enum ParamWriter {
    /// Root parameter stream.
    Root(RootParamWriter),
    /// Nested parameter value.
    Nested(IndexedParamWriter),
}
973
974impl wrpc_transport::Index<Self> for ParamWriter {
975 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
976 ensure!(!path.is_empty());
977 match self {
978 ParamWriter::Root(w) => w.index(path),
979 ParamWriter::Nested(w) => w.index(path),
980 }
981 .map(Self::Nested)
982 }
983}
984
985impl AsyncWrite for ParamWriter {
986 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
987 fn poll_write(
988 mut self: Pin<&mut Self>,
989 cx: &mut Context<'_>,
990 buf: &[u8],
991 ) -> Poll<std::io::Result<usize>> {
992 match &mut *self {
993 ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
994 ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
995 }
996 }
997
998 #[instrument(level = "trace", skip_all, ret)]
999 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1000 match &mut *self {
1001 ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1002 ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1003 }
1004 }
1005
1006 #[instrument(level = "trace", skip_all, ret)]
1007 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1008 match &mut *self {
1009 ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1010 ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1011 }
1012 }
1013}
1014
1015impl wrpc_transport::Invoke for Client {
1016 type Context = Option<HeaderMap>;
1017 type Outgoing = ParamWriter;
1018 type Incoming = Reader;
1019
1020 #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
1021 async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
1022 &self,
1023 cx: Self::Context,
1024 instance: &str,
1025 func: &str,
1026 mut params: Bytes,
1027 paths: impl AsRef<[P]> + Send,
1028 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
1029 let paths = paths.as_ref();
1030 let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));
1031
1032 let rx = Subject::from(new_inbox(&self.inbox));
1033 let (handshake_tx, handshake_rx) = mpsc::channel(1);
1034 cmds.push(Command::Subscribe(rx.clone(), handshake_tx));
1035
1036 let result = Subject::from(result_subject(&rx));
1037 let (result_tx, result_rx) = mpsc::channel(16);
1038 cmds.push(Command::Subscribe(result.clone(), result_tx));
1039
1040 let nested = paths.iter().map(|path| {
1041 let (tx, rx) = mpsc::channel(16);
1042 let subject = Subject::from(subscribe_path(&result, path.as_ref()));
1043 cmds.push(Command::Subscribe(subject.clone(), tx));
1044 Subscriber {
1045 rx: ReceiverStream::new(rx),
1046 commands: self.commands.clone(),
1047 subject,
1048 tasks: Arc::clone(&self.tasks),
1049 }
1050 });
1051 let nested: IndexTrie = zip(paths.iter(), nested).collect();
1052 ensure!(
1053 paths.is_empty() == nested.is_empty(),
1054 "failed to construct subscription tree"
1055 );
1056
1057 self.commands
1058 .send(Command::Batch(cmds.into_boxed_slice()))
1059 .await
1060 .context("failed to subscribe")?;
1061
1062 let ServerInfo {
1063 mut max_payload, ..
1064 } = self.nats.server_info();
1065 max_payload = max_payload.saturating_sub(rx.len());
1066 let param_tx = Subject::from(invocation_subject(&self.prefix, instance, func));
1067 if let Some(headers) = cx {
1068 max_payload = max_payload
1070 .saturating_sub(b"NATS/1.0\r\n".len())
1071 .saturating_sub(b"\r\n".len());
1072 for (k, vs) in headers.iter() {
1073 let k: &[u8] = k.as_ref();
1074 for v in vs {
1075 max_payload = max_payload
1076 .saturating_sub(k.len())
1077 .saturating_sub(b": ".len())
1078 .saturating_sub(v.as_str().len())
1079 .saturating_sub(b"\r\n".len());
1080 }
1081 }
1082 trace!("publishing handshake");
1083 self.nats
1084 .publish_with_reply_and_headers(
1085 param_tx,
1086 rx.clone(),
1087 headers,
1088 params.split_to(max_payload.min(params.len())),
1089 )
1090 .await
1091 } else {
1092 trace!("publishing handshake");
1093 self.nats
1094 .publish_with_reply(
1095 param_tx,
1096 rx.clone(),
1097 params.split_to(max_payload.min(params.len())),
1098 )
1099 .await
1100 }
1101 .context("failed to publish handshake")?;
1102 let nats = Arc::clone(&self.nats);
1103 tokio::spawn(async move {
1104 if let Err(err) = nats.flush().await {
1105 error!(?err, "failed to flush");
1106 }
1107 });
1108 Ok((
1109 ParamWriter::Root(RootParamWriter::new(
1110 (*self.nats).clone(),
1111 Subscriber {
1112 rx: ReceiverStream::new(handshake_rx),
1113 commands: self.commands.clone(),
1114 subject: rx,
1115 tasks: Arc::clone(&self.tasks),
1116 },
1117 params,
1118 Arc::clone(&self.tasks),
1119 )),
1120 Reader {
1121 buffer: Bytes::default(),
1122 incoming: Some(Subscriber {
1123 rx: ReceiverStream::new(result_rx),
1124 commands: self.commands.clone(),
1125 subject: result,
1126 tasks: Arc::clone(&self.tasks),
1127 }),
1128 nested: Arc::new(std::sync::Mutex::new(nested)),
1129 path: Box::default(),
1130 },
1131 ))
1132 }
1133}
1134
1135async fn handle_message(
1136 nats: &async_nats::Client,
1137 rx: Subject,
1138 commands: mpsc::Sender<Command>,
1139 async_nats::Message {
1140 reply: tx,
1141 payload,
1142 headers,
1143 ..
1144 }: async_nats::Message,
1145 paths: &[Box<[Option<usize>]>],
1146 tasks: Arc<JoinSet<()>>,
1147) -> anyhow::Result<(Option<HeaderMap>, SubjectWriter, Reader)> {
1148 let tx = tx.context("peer did not specify a reply subject")?;
1149
1150 let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1151
1152 let param = Subject::from(param_subject(&rx));
1153 let (param_tx, param_rx) = mpsc::channel(16);
1154 cmds.push(Command::Subscribe(param.clone(), param_tx));
1155
1156 let nested = paths.iter().map(|path| {
1157 let (tx, rx) = mpsc::channel(16);
1158 let subject = Subject::from(subscribe_path(¶m, path.as_ref()));
1159 cmds.push(Command::Subscribe(subject.clone(), tx));
1160 Subscriber {
1161 rx: ReceiverStream::new(rx),
1162 commands: commands.clone(),
1163 subject,
1164 tasks: Arc::clone(&tasks),
1165 }
1166 });
1167 let nested: IndexTrie = zip(paths.iter(), nested).collect();
1168 ensure!(
1169 paths.is_empty() == nested.is_empty(),
1170 "failed to construct subscription tree"
1171 );
1172
1173 commands
1174 .send(Command::Batch(cmds.into_boxed_slice()))
1175 .await
1176 .context("failed to subscribe")?;
1177
1178 trace!("publishing handshake response");
1179 nats.publish_with_reply(tx.clone(), rx, Bytes::default())
1180 .await
1181 .context("failed to publish handshake accept")?;
1182 Ok((
1183 headers,
1184 SubjectWriter::new(
1185 nats.clone(),
1186 Subject::from(result_subject(&tx)),
1187 Arc::clone(&tasks),
1188 ),
1189 Reader {
1190 buffer: payload,
1191 incoming: Some(Subscriber {
1192 rx: ReceiverStream::new(param_rx),
1193 commands,
1194 subject: param,
1195 tasks,
1196 }),
1197 nested: Arc::new(std::sync::Mutex::new(nested)),
1198 path: Box::default(),
1199 },
1200 ))
1201}
1202
1203impl wrpc_transport::Serve for Client {
1204 type Context = Option<HeaderMap>;
1205 type Outgoing = SubjectWriter;
1206 type Incoming = Reader;
1207
1208 #[instrument(level = "trace", skip(self, paths))]
1209 async fn serve(
1210 &self,
1211 instance: &str,
1212 func: &str,
1213 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1214 ) -> anyhow::Result<
1215 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1216 > {
1217 let subject = invocation_subject(&self.prefix, instance, func);
1218 let sub = if let Some(group) = &self.queue_group {
1219 debug!(subject, ?group, "queue-subscribing on invocation subject");
1220 self.nats
1221 .queue_subscribe(subject, group.to_string())
1222 .await?
1223 } else {
1224 debug!(subject, "subscribing on invocation subject");
1225 self.nats.subscribe(subject).await?
1226 };
1227 let nats = Arc::clone(&self.nats);
1228 let paths = paths.into();
1229 let commands = self.commands.clone();
1230 let inbox = Arc::clone(&self.inbox);
1231 let tasks = Arc::clone(&self.tasks);
1232 Ok(sub.then(move |msg| {
1233 let tasks = Arc::clone(&tasks);
1234 let nats = Arc::clone(&nats);
1235 let paths = Arc::clone(&paths);
1236 let commands = commands.clone();
1237 let rx = Subject::from(new_inbox(&inbox));
1238 async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1239 }))
1240 }
1241}