modio/types/
utils.rs

1use serde::de::{DeserializeOwned, Error, MapAccess};
2
3mod smallstr {
4    use std::fmt;
5
6    use serde::de::{Deserialize, Deserializer, Error, Visitor};
7    use serde::ser::{Serialize, Serializer};
8
9    #[derive(Clone, Copy, Eq, Hash, PartialEq)]
10    pub struct SmallStr<const LENGTH: usize> {
11        bytes: [u8; LENGTH],
12    }
13
14    impl<const LENGTH: usize> SmallStr<LENGTH> {
15        pub(crate) const fn from_str(input: &str) -> Option<Self> {
16            if input.len() > LENGTH {
17                return None;
18            }
19            Some(Self::from_bytes(input.as_bytes()))
20        }
21
22        pub(crate) const fn from_bytes(input: &[u8]) -> Self {
23            let mut bytes = [0; LENGTH];
24            let mut idx = 0;
25
26            while idx < input.len() {
27                bytes[idx] = input[idx];
28                idx += 1;
29            }
30
31            Self { bytes }
32        }
33
34        pub fn as_str(&self) -> &str {
35            std::str::from_utf8(&self.bytes)
36                .expect("invalid utf8 string")
37                .trim_end_matches('\0')
38        }
39    }
40
41    impl<const LENGTH: usize> fmt::Debug for SmallStr<LENGTH> {
42        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43            f.write_str(self.as_str())
44        }
45    }
46
47    impl<'de, const LENGTH: usize> Deserialize<'de> for SmallStr<LENGTH> {
48        fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
49            struct StrVisitor<const LENGTH: usize>;
50
51            impl<const LENGTH: usize> Visitor<'_> for StrVisitor<LENGTH> {
52                type Value = SmallStr<LENGTH>;
53
54                fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
55                    formatter.write_str("string")
56                }
57
58                fn visit_str<E: Error>(self, v: &str) -> Result<Self::Value, E> {
59                    SmallStr::from_str(v).ok_or_else(|| Error::custom("string is too long"))
60                }
61
62                fn visit_string<E: Error>(self, v: String) -> Result<Self::Value, E> {
63                    SmallStr::from_str(&v).ok_or_else(|| Error::custom("string is too long"))
64                }
65            }
66
67            deserializer.deserialize_any(StrVisitor::<LENGTH>)
68        }
69    }
70
71    impl<const LENGTH: usize> Serialize for SmallStr<LENGTH> {
72        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
73            self.as_str().serialize(serializer)
74        }
75    }
76}
77pub use smallstr::SmallStr;
78
79pub mod url {
80    use std::fmt;
81
82    use serde::de::{Deserializer, Error, Visitor};
83    use url::Url;
84
85    struct UrlVisitor;
86
87    impl Visitor<'_> for UrlVisitor {
88        type Value = Url;
89
90        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
91            formatter.write_str("a string representing an URL")
92        }
93
94        fn visit_str<E: Error>(self, s: &str) -> Result<Self::Value, E> {
95            Url::parse(s).map_err(|err| Error::custom(format!("{err}: {s:?}")))
96        }
97    }
98
99    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Url, D::Error> {
100        deserializer.deserialize_any(UrlVisitor)
101    }
102
103    pub mod opt {
104        use std::fmt;
105
106        use serde::de::{Deserializer, Error, Visitor};
107        use url::Url;
108
109        struct UrlVisitor;
110
111        impl<'de> Visitor<'de> for UrlVisitor {
112            type Value = Option<Url>;
113
114            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
115                formatter.write_str("an optional string representing an URL")
116            }
117
118            fn visit_some<D: Deserializer<'de>>(self, d: D) -> Result<Self::Value, D::Error> {
119                d.deserialize_any(super::UrlVisitor).map(Some)
120            }
121
122            fn visit_none<E: Error>(self) -> Result<Self::Value, E> {
123                Ok(None)
124            }
125
126            fn visit_unit<E: Error>(self) -> Result<Self::Value, E> {
127                Ok(None)
128            }
129        }
130
131        pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Url>, D::Error> {
132            d.deserialize_option(UrlVisitor)
133        }
134    }
135}
136
137pub trait DeserializeField<T: DeserializeOwned> {
138    fn deserialize_value<'de, A: MapAccess<'de>>(
139        &mut self,
140        name: &'static str,
141        map: &mut A,
142    ) -> Result<(), A::Error>;
143}
144
145impl<T: DeserializeOwned> DeserializeField<T> for Option<T> {
146    fn deserialize_value<'de, A>(&mut self, name: &'static str, map: &mut A) -> Result<(), A::Error>
147    where
148        A: MapAccess<'de>,
149    {
150        if self.is_some() {
151            return Err(A::Error::duplicate_field(name));
152        }
153        self.replace(map.next_value()?);
154        Ok(())
155    }
156}
157
158pub trait MissingField<T> {
159    fn missing_field<E: Error>(self, name: &'static str) -> Result<T, E>;
160}
161
162impl<T> MissingField<T> for Option<T> {
163    fn missing_field<E: Error>(self, name: &'static str) -> Result<T, E> {
164        self.ok_or_else(|| Error::missing_field(name))
165    }
166}