1#![allow(clippy::module_name_repetitions)]
12
13use std::{
14 borrow::Cow,
15 collections::BTreeSet,
16 iter::FromIterator,
17 ops::{Deref, DerefMut},
18 str::FromStr,
19};
20
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
26#[error("Invalid scope format")]
27pub struct InvalidScope;
28
/// A single scope token.
///
/// Tokens obtained through `FromStr` are validated; tokens built with
/// [`ScopeToken::from_static`] are taken as-is.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Wraps a `'static` string as a token without copying or validating it.
    ///
    /// Being a `const fn`, this can be used to declare token constants.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        Self(Cow::Borrowed(token))
    }

    /// Borrows the token's text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
47
48pub const OPENID: ScopeToken = ScopeToken::from_static("openid");
52
53pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");
57
58pub const EMAIL: ScopeToken = ScopeToken::from_static("email");
62
63pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");
67
68pub const PHONE: ScopeToken = ScopeToken::from_static("phone");
72
73pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
79
/// Returns `true` if `c` may appear in a scope token.
///
/// Implements the `NQCHAR` production of RFC 6749 (`scope-token = 1*NQCHAR`):
/// `%x21 / %x23-5B / %x5D-7E`. That is every printable ASCII character except
/// space (0x20), `"` (0x22) and `\` (0x5C).
///
/// Both ranges in the grammar are inclusive, so `[` (0x5B) and `~` (0x7E) are
/// accepted. Hence `..=` rather than the half-open `..`.
fn nqchar(c: char) -> bool {
    c == '\x21' || ('\x23'..='\x5B').contains(&c) || ('\x5D'..='\x7E').contains(&c)
}
87
88impl FromStr for ScopeToken {
89 type Err = InvalidScope;
90
91 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 if !s.is_empty() && s.chars().all(nqchar) {
97 Ok(ScopeToken(Cow::Owned(s.into())))
98 } else {
99 Err(InvalidScope)
100 }
101 }
102}
103
104impl Deref for ScopeToken {
105 type Target = str;
106
107 fn deref(&self) -> &Self::Target {
108 &self.0
109 }
110}
111
112impl std::fmt::Display for ScopeToken {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 self.0.fmt(f)
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
120pub struct Scope(BTreeSet<ScopeToken>);
121
122impl Deref for Scope {
123 type Target = BTreeSet<ScopeToken>;
124
125 fn deref(&self) -> &Self::Target {
126 &self.0
127 }
128}
129
130impl DerefMut for Scope {
131 fn deref_mut(&mut self) -> &mut Self::Target {
132 &mut self.0
133 }
134}
135
136impl FromStr for Scope {
137 type Err = InvalidScope;
138
139 fn from_str(s: &str) -> Result<Self, Self::Err> {
140 let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
145 s.split(' ').map(ScopeToken::from_str).collect();
146
147 Ok(Self(scopes?))
148 }
149}
150
151impl Scope {
152 #[must_use]
154 pub fn is_empty(&self) -> bool {
155 self.0.is_empty()
157 }
158
159 #[must_use]
161 pub fn len(&self) -> usize {
162 self.0.len()
163 }
164
165 #[must_use]
167 pub fn contains(&self, token: &str) -> bool {
168 ScopeToken::from_str(token)
169 .map(|token| self.0.contains(&token))
170 .unwrap_or(false)
171 }
172
173 pub fn insert(&mut self, value: ScopeToken) -> bool {
177 self.0.insert(value)
178 }
179}
180
181impl std::fmt::Display for Scope {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 for (index, token) in self.0.iter().enumerate() {
184 if index == 0 {
185 write!(f, "{token}")?;
186 } else {
187 write!(f, " {token}")?;
188 }
189 }
190
191 Ok(())
192 }
193}
194
195impl Serialize for Scope {
196 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
197 where
198 S: serde::Serializer,
199 {
200 self.to_string().serialize(serializer)
201 }
202}
203
204impl<'de> Deserialize<'de> for Scope {
205 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
206 where
207 D: serde::Deserializer<'de>,
208 {
209 let scope: String = Deserialize::deserialize(deserializer)?;
211 Scope::from_str(&scope).map_err(serde::de::Error::custom)
212 }
213}
214
215impl FromIterator<ScopeToken> for Scope {
216 fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
217 Self(BTreeSet::from_iter(iter))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));

        // `\` (0x5C) is not an NQCHAR.
        assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        assert!(scope.contains("openid"));
        assert!(scope.contains("profile"));
        assert!(scope.contains("address"));
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        assert!(Scope::from_str("invalid\\scope").is_err());
        // Each of these yields an empty token next to a separator, which is
        // invalid. The first needs genuinely doubled spaces to test anything.
        assert!(Scope::from_str("no  double  space").is_err());
        assert!(Scope::from_str(" no leading space").is_err());
        assert!(Scope::from_str("no trailing space ").is_err());

        let scope = Scope::from_str("openid").unwrap();
        assert_eq!(scope.len(), 1);
        assert!(scope.contains("openid"));
        assert!(!scope.contains("profile"));
        assert!(!scope.contains("address"));

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        assert!(Scope::from_str("http://example.com").is_ok());
        assert!(Scope::from_str("urn:matrix:client:api:*").is_ok());
        assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:api:*").is_ok());
    }

    #[test]
    fn display_is_sorted_deduplicated_and_round_trips() {
        let scope = Scope::from_str("profile openid email openid").unwrap();
        // Duplicates collapse and tokens come back out in sorted order.
        assert_eq!(scope.len(), 3);
        assert_eq!(scope.to_string(), "email openid profile");
        assert_eq!(Scope::from_str(&scope.to_string()), Ok(scope));
    }
}