wit_deps/
digest.rs

1use core::fmt;
2use core::pin::Pin;
3use core::task::{Context, Poll};
4
5use futures::{AsyncRead, AsyncWrite};
6use hex::FromHex;
7use serde::ser::SerializeStruct;
8use serde::{de, Deserialize, Serialize};
9use sha2::{Digest as _, Sha256, Sha512};
10
/// Content digest of a resource: its SHA-256 and SHA-512 hashes as raw bytes.
///
/// Equality and hashing take both hashes into account.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Digest {
    /// Raw 32-byte SHA-256 hash of the resource.
    pub sha256: [u8; 32],
    /// Raw 64-byte SHA-512 hash of the resource.
    pub sha512: [u8; 64],
}
19
20impl<'de> Deserialize<'de> for Digest {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: serde::Deserializer<'de>,
24    {
25        const FIELDS: [&str; 2] = ["sha256", "sha512"];
26
27        struct Visitor;
28        impl<'de> de::Visitor<'de> for Visitor {
29            type Value = Digest;
30
31            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
32                formatter.write_str("a resource digest")
33            }
34
35            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
36            where
37                V: de::MapAccess<'de>,
38            {
39                let mut sha256 = None;
40                let mut sha512 = None;
41                while let Some((k, v)) = map.next_entry::<String, String>()? {
42                    match k.as_ref() {
43                        "sha256" => {
44                            if sha256.is_some() {
45                                return Err(de::Error::duplicate_field("sha256"));
46                            }
47                            sha256 = FromHex::from_hex(v).map(Some).map_err(|e| {
48                                de::Error::custom(format!("invalid `sha256` field value: {e}"))
49                            })?;
50                        }
51                        "sha512" => {
52                            if sha512.is_some() {
53                                return Err(de::Error::duplicate_field("sha512"));
54                            }
55                            sha512 = FromHex::from_hex(v).map(Some).map_err(|e| {
56                                de::Error::custom(format!("invalid `sha512` field value: {e}"))
57                            })?;
58                        }
59                        k => return Err(de::Error::unknown_field(k, &FIELDS)),
60                    }
61                }
62                let sha256 = sha256.ok_or_else(|| de::Error::missing_field("sha256"))?;
63                let sha512 = sha512.ok_or_else(|| de::Error::missing_field("sha512"))?;
64                Ok(Digest { sha256, sha512 })
65            }
66        }
67        deserializer.deserialize_struct("Entry", &FIELDS, Visitor)
68    }
69}
70
71impl Serialize for Digest {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        let mut state = serializer.serialize_struct("Digest", 2)?;
77        state.serialize_field("sha256", &hex::encode(self.sha256))?;
78        state.serialize_field("sha512", &hex::encode(self.sha512))?;
79        state.end()
80    }
81}
82
83/// A reader wrapper, which hashes the bytes read
84pub struct Reader<T> {
85    inner: T,
86    sha256: Sha256,
87    sha512: Sha512,
88}
89
90impl<T: AsyncRead + Unpin> AsyncRead for Reader<T> {
91    fn poll_read(
92        mut self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94        buf: &mut [u8],
95    ) -> Poll<std::io::Result<usize>> {
96        Pin::new(&mut self.inner).poll_read(cx, buf).map_ok(|n| {
97            self.sha256.update(&buf[..n]);
98            self.sha512.update(&buf[..n]);
99            n
100        })
101    }
102}
103
104impl<T> From<T> for Reader<T> {
105    fn from(inner: T) -> Self {
106        Self {
107            inner,
108            sha256: Sha256::new(),
109            sha512: Sha512::new(),
110        }
111    }
112}
113
114impl<T> From<Reader<T>> for Digest {
115    fn from(hashed: Reader<T>) -> Self {
116        let sha256 = hashed.sha256.finalize().into();
117        let sha512 = hashed.sha512.finalize().into();
118        Self { sha256, sha512 }
119    }
120}
121
122/// A writer wrapper, which hashes the bytes written
123pub struct Writer<T> {
124    inner: T,
125    sha256: Sha256,
126    sha512: Sha512,
127}
128
129impl<T: AsyncWrite + Unpin> AsyncWrite for Writer<T> {
130    fn poll_write(
131        mut self: Pin<&mut Self>,
132        cx: &mut Context<'_>,
133        buf: &[u8],
134    ) -> Poll<std::io::Result<usize>> {
135        Pin::new(&mut self.inner).poll_write(cx, buf).map_ok(|n| {
136            self.sha256.update(&buf[..n]);
137            self.sha512.update(&buf[..n]);
138            n
139        })
140    }
141
142    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
143        Pin::new(&mut self.inner).poll_flush(cx)
144    }
145
146    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
147        Pin::new(&mut self.inner).poll_close(cx)
148    }
149}
150
151impl<T> From<T> for Writer<T> {
152    fn from(inner: T) -> Self {
153        Self {
154            inner,
155            sha256: Sha256::new(),
156            sha512: Sha512::new(),
157        }
158    }
159}
160
161impl<T> From<Writer<T>> for Digest {
162    fn from(hashed: Writer<T>) -> Self {
163        let sha256 = hashed.sha256.finalize().into();
164        let sha512 = hashed.sha512.finalize().into();
165        Self { sha256, sha512 }
166    }
167}