use crate::Signal;
use async_io::Async;
use futures_core::ready;
use futures_io::AsyncRead;
use std::io::{self, prelude::*};
use std::mem;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::pin::Pin;
use std::sync::{Mutex, PoisonError};
use std::task::{Context, Poll};
const BUFFER_LEN: usize = mem::size_of::<std::os::raw::c_int>();
#[derive(Debug)]
pub(super) struct Notifier {
read: Async<UnixStream>,
write: UnixStream,
}
impl Notifier {
pub(super) fn new() -> io::Result<Self> {
let (read, write) = UnixStream::pair()?;
let read = Async::new(read)?;
write.set_nonblocking(true)?;
Ok(Self { read, write })
}
pub(super) fn add_signal(
&mut self,
signal: Signal,
) -> io::Result<impl Fn() + Send + Sync + 'static> {
let number = signal.number();
let write = self.write.try_clone()?;
Ok(move || {
let bytes = number.to_ne_bytes();
let _ = (&write).write(&bytes);
})
}
pub(super) fn remove_signal(&mut self, _signal: Signal) -> io::Result<()> {
Ok(())
}
pub(super) fn poll_next(&self, cx: &mut Context<'_>) -> Poll<io::Result<Signal>> {
let mut buffer = [0; BUFFER_LEN];
let mut buffer_len = 0;
loop {
if buffer_len >= BUFFER_LEN {
break;
}
let buf_range = buffer_len..BUFFER_LEN;
let res = ready!(Pin::new(&mut &self.read).poll_read(cx, &mut buffer[buf_range]));
match res {
Ok(0) => return Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof))),
Ok(n) => buffer_len += n,
Err(e) => return Poll::Ready(Err(e)),
}
}
let number = std::os::raw::c_int::from_ne_bytes(buffer);
let signal = match Signal::from_number(number) {
Some(signal) => signal,
None => return Poll::Ready(Err(io::Error::from(io::ErrorKind::InvalidData))),
};
Poll::Ready(Ok(signal))
}
}
impl AsRawFd for Notifier {
    /// Returns the raw descriptor of the socket pair's read half, not the
    /// write half.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.read)
    }
}
impl AsFd for Notifier {
    /// Borrows the descriptor of the socket pair's read half for as long as
    /// the notifier is borrowed.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.read)
    }
}