oauth2_types/
scope.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! Types to define an [access token's scope].
8//!
9//! [access token's scope]: https://www.rfc-editor.org/rfc/rfc6749#section-3.3
10
11#![allow(clippy::module_name_repetitions)]
12
13use std::{
14    borrow::Cow,
15    collections::BTreeSet,
16    iter::FromIterator,
17    ops::{Deref, DerefMut},
18    str::FromStr,
19};
20
21use serde::{Deserialize, Serialize};
22use thiserror::Error;
23
24/// The error type returned when a scope is invalid.
25#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
26#[error("Invalid scope format")]
27pub struct InvalidScope;
28
/// A single scope token (also called a scope value), such as `openid`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Builds a `ScopeToken` from a string literal without copying it.
    ///
    /// This is a `const fn` so it can be used to define constants, and for
    /// that reason the token is *not* validated. Use [`str::parse`] for input
    /// that has not been checked already.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        let borrowed: Cow<'static, str> = Cow::Borrowed(token);
        Self(borrowed)
    }

    /// Returns the token's text as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
47
// Well-known OpenID Connect scope values. They are created with
// `ScopeToken::from_static`, which does not validate its input, so each
// literal below must already be a valid scope token.

/// `openid`.
///
/// Must be included in OpenID Connect requests.
pub const OPENID: ScopeToken = ScopeToken::from_static("openid");

/// `profile`.
///
/// Requests access to the End-User's default profile Claims.
pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");

/// `email`.
///
/// Requests access to the `email` and `email_verified` Claims.
pub const EMAIL: ScopeToken = ScopeToken::from_static("email");

/// `address`.
///
/// Requests access to the `address` Claim.
pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");

/// `phone`.
///
/// Requests access to the `phone_number` and `phone_number_verified` Claims.
pub const PHONE: ScopeToken = ScopeToken::from_static("phone");

/// `offline_access`.
///
/// Requests that an OAuth 2.0 Refresh Token be issued that can be used to
/// obtain an Access Token that grants access to the End-User's Userinfo
/// Endpoint even when the End-User is not present (not logged in).
pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
79
// As per RFC6749 appendix A:
// https://datatracker.ietf.org/doc/html/rfc6749#appendix-A
//
//    NQCHAR     = %x21 / %x23-5B / %x5D-7E
//
/// Whether `c` is an `NQCHAR`: any visible ASCII character (0x21–0x7E) except
/// `"` (0x22) and `\` (0x5C).
///
/// ABNF value ranges are inclusive, so the end points `[` (0x5B) and `~`
/// (0x7E) are valid; the inclusive range patterns keep them.
fn nqchar(c: char) -> bool {
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
87
88impl FromStr for ScopeToken {
89    type Err = InvalidScope;
90
91    fn from_str(s: &str) -> Result<Self, Self::Err> {
92        // As per RFC6749 appendix A.4:
93        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
94        //
95        //    scope-token = 1*NQCHAR
96        if !s.is_empty() && s.chars().all(nqchar) {
97            Ok(ScopeToken(Cow::Owned(s.into())))
98        } else {
99            Err(InvalidScope)
100        }
101    }
102}
103
104impl Deref for ScopeToken {
105    type Target = str;
106
107    fn deref(&self) -> &Self::Target {
108        &self.0
109    }
110}
111
112impl std::fmt::Display for ScopeToken {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        self.0.fmt(f)
115    }
116}
117
118/// A scope.
119#[derive(Debug, Clone, PartialEq, Eq)]
120pub struct Scope(BTreeSet<ScopeToken>);
121
122impl Deref for Scope {
123    type Target = BTreeSet<ScopeToken>;
124
125    fn deref(&self) -> &Self::Target {
126        &self.0
127    }
128}
129
130impl DerefMut for Scope {
131    fn deref_mut(&mut self) -> &mut Self::Target {
132        &mut self.0
133    }
134}
135
136impl FromStr for Scope {
137    type Err = InvalidScope;
138
139    fn from_str(s: &str) -> Result<Self, Self::Err> {
140        // As per RFC6749 appendix A.4:
141        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
142        //
143        //    scope       = scope-token *( SP scope-token )
144        let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
145            s.split(' ').map(ScopeToken::from_str).collect();
146
147        Ok(Self(scopes?))
148    }
149}
150
151impl Scope {
152    /// Whether this `Scope` is empty.
153    #[must_use]
154    pub fn is_empty(&self) -> bool {
155        // This should never be the case?
156        self.0.is_empty()
157    }
158
159    /// The number of tokens in the `Scope`.
160    #[must_use]
161    pub fn len(&self) -> usize {
162        self.0.len()
163    }
164
165    /// Whether this `Scope` contains the given value.
166    #[must_use]
167    pub fn contains(&self, token: &str) -> bool {
168        ScopeToken::from_str(token)
169            .map(|token| self.0.contains(&token))
170            .unwrap_or(false)
171    }
172
173    /// Inserts the given token in this `Scope`.
174    ///
175    /// Returns whether the token was newly inserted.
176    pub fn insert(&mut self, value: ScopeToken) -> bool {
177        self.0.insert(value)
178    }
179}
180
181impl std::fmt::Display for Scope {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        for (index, token) in self.0.iter().enumerate() {
184            if index == 0 {
185                write!(f, "{token}")?;
186            } else {
187                write!(f, " {token}")?;
188            }
189        }
190
191        Ok(())
192    }
193}
194
195impl Serialize for Scope {
196    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
197    where
198        S: serde::Serializer,
199    {
200        self.to_string().serialize(serializer)
201    }
202}
203
204impl<'de> Deserialize<'de> for Scope {
205    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
206    where
207        D: serde::Deserializer<'de>,
208    {
209        // FIXME: seems like there is an unnecessary clone here?
210        let scope: String = Deserialize::deserialize(deserializer)?;
211        Scope::from_str(&scope).map_err(serde::de::Error::custom)
212    }
213}
214
215impl FromIterator<ScopeToken> for Scope {
216    fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
217        Self(BTreeSet::from_iter(iter))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));

        assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
    }

    #[test]
    fn parse_scope_token_boundaries() {
        // Empty tokens, and tokens containing SP, `"`, `\`, control or
        // non-ASCII characters, are all outside `1*NQCHAR`.
        for invalid in ["", "a b", "a\"b", "a\\b", "a\tb", "a\x7Fb", "caf\u{e9}"] {
            assert_eq!(
                ScopeToken::from_str(invalid),
                Err(InvalidScope),
                "{invalid:?} should be rejected"
            );
        }

        for valid in ["!", "#", "Z", "]", "}", "urn:example:*"] {
            assert!(
                ScopeToken::from_str(valid).is_ok(),
                "{valid:?} should be accepted"
            );
        }
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        assert!(scope.contains("openid"));
        assert!(scope.contains("profile"));
        assert!(scope.contains("address"));
        assert!(!scope.contains("unknown"));
        assert!(!scope.contains("invalid\\scope"));
        assert!(!scope.contains(""));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        assert!(Scope::from_str("invalid\\scope").is_err());
        assert!(Scope::from_str("no  double space").is_err());
        assert!(Scope::from_str(" no leading space").is_err());
        assert!(Scope::from_str("no trailing space ").is_err());

        let scope = Scope::from_str("openid").unwrap();
        assert_eq!(scope.len(), 1);
        assert!(scope.contains("openid"));
        assert!(!scope.contains("profile"));
        assert!(!scope.contains("address"));

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        assert!(Scope::from_str("http://example.com").is_ok());
        assert!(Scope::from_str("urn:matrix:client:api:*").is_ok());
        assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:api:*").is_ok());
    }

    #[test]
    fn display_scope() {
        assert_eq!(OPENID.to_string(), "openid");
        assert_eq!(OPENID.as_str(), "openid");

        let scope = Scope::from_str("profile openid address").unwrap();
        // Tokens are stored in a `BTreeSet`, so they are written back sorted.
        assert_eq!(scope.to_string(), "address openid profile");
        assert_eq!(Scope::from_str(&scope.to_string()), Ok(scope));

        let scope = Scope::from_iter([OPENID, EMAIL, OPENID]);
        assert_eq!(scope.len(), 2);
        assert_eq!(scope.to_string(), "email openid");
    }
}