wit_deps/
digest.rs

1use core::fmt;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4
5use hex::FromHex;
6use serde::ser::SerializeStruct;
7use serde::{de, Deserialize, Serialize};
8use sha2::{Digest as _, Sha256, Sha512};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
/// Content digest of a resource, carrying both its SHA-256 and SHA-512 hashes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Digest {
    /// 32-byte SHA-256 hash of the resource.
    pub sha256: [u8; 32],
    /// 64-byte SHA-512 hash of the resource.
    pub sha512: [u8; 64],
}
19
20impl<'de> Deserialize<'de> for Digest {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: serde::Deserializer<'de>,
24    {
25        const FIELDS: [&str; 2] = ["sha256", "sha512"];
26
27        struct Visitor;
28        impl<'de> de::Visitor<'de> for Visitor {
29            type Value = Digest;
30
31            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
32                formatter.write_str("a resource digest")
33            }
34
35            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
36            where
37                V: de::MapAccess<'de>,
38            {
39                let mut sha256 = None;
40                let mut sha512 = None;
41                while let Some((k, v)) = map.next_entry::<String, String>()? {
42                    match k.as_ref() {
43                        "sha256" => {
44                            if sha256.is_some() {
45                                return Err(de::Error::duplicate_field("sha256"));
46                            }
47                            sha256 = FromHex::from_hex(v).map(Some).map_err(|e| {
48                                de::Error::custom(format!("invalid `sha256` field value: {e}"))
49                            })?;
50                        }
51                        "sha512" => {
52                            if sha512.is_some() {
53                                return Err(de::Error::duplicate_field("sha512"));
54                            }
55                            sha512 = FromHex::from_hex(v).map(Some).map_err(|e| {
56                                de::Error::custom(format!("invalid `sha512` field value: {e}"))
57                            })?;
58                        }
59                        k => return Err(de::Error::unknown_field(k, &FIELDS)),
60                    }
61                }
62                let sha256 = sha256.ok_or_else(|| de::Error::missing_field("sha256"))?;
63                let sha512 = sha512.ok_or_else(|| de::Error::missing_field("sha512"))?;
64                Ok(Digest { sha256, sha512 })
65            }
66        }
67        deserializer.deserialize_struct("Entry", &FIELDS, Visitor)
68    }
69}
70
71impl Serialize for Digest {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        let mut state = serializer.serialize_struct("Digest", 2)?;
77        state.serialize_field("sha256", &hex::encode(self.sha256))?;
78        state.serialize_field("sha512", &hex::encode(self.sha512))?;
79        state.end()
80    }
81}
82
83/// A reader wrapper, which hashes the bytes read
84pub struct Reader<T> {
85    inner: T,
86    sha256: Sha256,
87    sha512: Sha512,
88}
89
90impl<T: AsyncRead + Unpin> AsyncRead for Reader<T> {
91    fn poll_read(
92        mut self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94        buf: &mut ReadBuf<'_>,
95    ) -> Poll<std::io::Result<()>> {
96        let n = buf.filled().len();
97        match Pin::new(&mut self.inner).poll_read(cx, buf) {
98            Poll::Ready(Ok(())) => {
99                let buf = buf.filled();
100                self.sha256.update(&buf[n..]);
101                self.sha512.update(&buf[n..]);
102                Poll::Ready(Ok(()))
103            }
104            other => other,
105        }
106    }
107}
108
109impl<T> From<T> for Reader<T> {
110    fn from(inner: T) -> Self {
111        Self {
112            inner,
113            sha256: Sha256::new(),
114            sha512: Sha512::new(),
115        }
116    }
117}
118
119impl<T> From<Reader<T>> for Digest {
120    fn from(hashed: Reader<T>) -> Self {
121        let sha256 = hashed.sha256.finalize().into();
122        let sha512 = hashed.sha512.finalize().into();
123        Self { sha256, sha512 }
124    }
125}
126
127/// A writer wrapper, which hashes the bytes written
128pub struct Writer<T> {
129    inner: T,
130    sha256: Sha256,
131    sha512: Sha512,
132}
133
134impl<T: AsyncWrite + Unpin> AsyncWrite for Writer<T> {
135    fn poll_write(
136        mut self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        buf: &[u8],
139    ) -> Poll<std::io::Result<usize>> {
140        Pin::new(&mut self.inner).poll_write(cx, buf).map_ok(|n| {
141            self.sha256.update(&buf[..n]);
142            self.sha512.update(&buf[..n]);
143            n
144        })
145    }
146
147    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
148        Pin::new(&mut self.inner).poll_flush(cx)
149    }
150
151    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
152        Pin::new(&mut self.inner).poll_shutdown(cx)
153    }
154}
155
156impl<T> From<T> for Writer<T> {
157    fn from(inner: T) -> Self {
158        Self {
159            inner,
160            sha256: Sha256::new(),
161            sha512: Sha512::new(),
162        }
163    }
164}
165
166impl<T> From<Writer<T>> for Digest {
167    fn from(hashed: Writer<T>) -> Self {
168        let sha256 = hashed.sha256.finalize().into();
169        let sha512 = hashed.sha512.finalize().into();
170        Self { sha256, sha512 }
171    }
172}