1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{BrowserSession, Client, Session, SessionState, User};
12use mas_storage::{
13 Clock, Page, Pagination,
14 oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
15};
16use oauth2_types::scope::{Scope, ScopeToken};
17use rand::RngCore;
18use sea_query::{
19 Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
20 extension::postgres::PgExpr,
21};
22use sea_query_binder::SqlxBinder;
23use sqlx::PgConnection;
24use ulid::Ulid;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt},
30 iden::{OAuth2Clients, OAuth2Sessions},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
35pub struct PgOAuth2SessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgOAuth2SessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// One row of the `oauth2_sessions` table, as read back by
/// [`PgOAuth2SessionRepository`].
///
/// Field names matter: `sqlx::FromRow` maps result columns onto fields by
/// name, and `#[enum_def]` generates the `OAuthSessionLookupIden` enum that
/// the dynamic `list` query uses to alias each selected column.
#[derive(sqlx::FromRow)]
#[enum_def]
struct OAuthSessionLookup {
    oauth2_session_id: Uuid,
    // NULL when the session was created without a user.
    user_id: Option<Uuid>,
    // NULL when the session was created without a browser session.
    user_session_id: Option<Uuid>,
    oauth2_client_id: Uuid,
    // Raw scope tokens; each is parsed back into a `ScopeToken` on conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // NULL while the session is still valid.
    finished_at: Option<DateTime<Utc>>,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    // Stored as a Postgres `inet`.
    last_active_ip: Option<IpAddr>,
    human_name: Option<String>,
}
63
64impl TryFrom<OAuthSessionLookup> for Session {
65 type Error = DatabaseInconsistencyError;
66
67 fn try_from(value: OAuthSessionLookup) -> Result<Self, Self::Error> {
68 let id = Ulid::from(value.oauth2_session_id);
69 let scope: Result<Scope, _> = value
70 .scope_list
71 .iter()
72 .map(|s| s.parse::<ScopeToken>())
73 .collect();
74 let scope = scope.map_err(|e| {
75 DatabaseInconsistencyError::on("oauth2_sessions")
76 .column("scope")
77 .row(id)
78 .source(e)
79 })?;
80
81 let state = match value.finished_at {
82 None => SessionState::Valid,
83 Some(finished_at) => SessionState::Finished { finished_at },
84 };
85
86 Ok(Session {
87 id,
88 state,
89 created_at: value.created_at,
90 client_id: value.oauth2_client_id.into(),
91 user_id: value.user_id.map(Ulid::from),
92 user_session_id: value.user_session_id.map(Ulid::from),
93 scope,
94 user_agent: value.user_agent,
95 last_active_at: value.last_active_at,
96 last_active_ip: value.last_active_ip,
97 human_name: value.human_name,
98 })
99 }
100}
101
102impl Filter for OAuth2SessionFilter<'_> {
103 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
104 sea_query::Condition::all()
105 .add_option(self.user().map(|user| {
106 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).eq(Uuid::from(user.id))
107 }))
108 .add_option(self.client().map(|client| {
109 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
110 .eq(Uuid::from(client.id))
111 }))
112 .add_option(self.client_kind().map(|client_kind| {
113 let static_clients = Query::select()
117 .expr(Expr::col((
118 OAuth2Clients::Table,
119 OAuth2Clients::OAuth2ClientId,
120 )))
121 .and_where(Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).into())
122 .from(OAuth2Clients::Table)
123 .take();
124 if client_kind.is_static() {
125 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
126 .eq(Expr::any(static_clients))
127 } else {
128 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
129 .ne(Expr::all(static_clients))
130 }
131 }))
132 .add_option(self.device().map(|device| -> SimpleExpr {
133 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
134 Condition::any()
135 .add(
136 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
137 OAuth2Sessions::Table,
138 OAuth2Sessions::ScopeList,
139 )))),
140 )
141 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
142 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
143 )))
144 .into()
145 } else {
146 Expr::val(false).into()
148 }
149 }))
150 .add_option(self.browser_session().map(|browser_session| {
151 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))
152 .eq(Uuid::from(browser_session.id))
153 }))
154 .add_option(self.state().map(|state| {
155 if state.is_active() {
156 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_null()
157 } else {
158 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)).is_not_null()
159 }
160 }))
161 .add_option(self.scope().map(|scope| {
162 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
163 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)).contains(scope)
164 }))
165 .add_option(self.any_user().map(|any_user| {
166 if any_user {
167 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_not_null()
168 } else {
169 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)).is_null()
170 }
171 }))
172 .add_option(self.last_active_after().map(|last_active_after| {
173 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
174 .gt(last_active_after)
175 }))
176 .add_option(self.last_active_before().map(|last_active_before| {
177 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt))
178 .lt(last_active_before)
179 }))
180 }
181}
182
183#[async_trait]
184impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> {
185 type Error = DatabaseError;
186
187 #[tracing::instrument(
188 name = "db.oauth2_session.lookup",
189 skip_all,
190 fields(
191 db.query.text,
192 session.id = %id,
193 ),
194 err,
195 )]
196 async fn lookup(&mut self, id: Ulid) -> Result<Option<Session>, Self::Error> {
197 let res = sqlx::query_as!(
198 OAuthSessionLookup,
199 r#"
200 SELECT oauth2_session_id
201 , user_id
202 , user_session_id
203 , oauth2_client_id
204 , scope_list
205 , created_at
206 , finished_at
207 , user_agent
208 , last_active_at
209 , last_active_ip as "last_active_ip: IpAddr"
210 , human_name
211 FROM oauth2_sessions
212
213 WHERE oauth2_session_id = $1
214 "#,
215 Uuid::from(id),
216 )
217 .traced()
218 .fetch_optional(&mut *self.conn)
219 .await?;
220
221 let Some(session) = res else { return Ok(None) };
222
223 Ok(Some(session.try_into()?))
224 }
225
226 #[tracing::instrument(
227 name = "db.oauth2_session.add",
228 skip_all,
229 fields(
230 db.query.text,
231 %client.id,
232 session.id,
233 session.scope = %scope,
234 ),
235 err,
236 )]
237 async fn add(
238 &mut self,
239 rng: &mut (dyn RngCore + Send),
240 clock: &dyn Clock,
241 client: &Client,
242 user: Option<&User>,
243 user_session: Option<&BrowserSession>,
244 scope: Scope,
245 ) -> Result<Session, Self::Error> {
246 let created_at = clock.now();
247 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
248 tracing::Span::current().record("session.id", tracing::field::display(id));
249
250 let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
251
252 sqlx::query!(
253 r#"
254 INSERT INTO oauth2_sessions
255 ( oauth2_session_id
256 , user_id
257 , user_session_id
258 , oauth2_client_id
259 , scope_list
260 , created_at
261 )
262 VALUES ($1, $2, $3, $4, $5, $6)
263 "#,
264 Uuid::from(id),
265 user.map(|u| Uuid::from(u.id)),
266 user_session.map(|s| Uuid::from(s.id)),
267 Uuid::from(client.id),
268 &scope_list,
269 created_at,
270 )
271 .traced()
272 .execute(&mut *self.conn)
273 .await?;
274
275 Ok(Session {
276 id,
277 state: SessionState::Valid,
278 created_at,
279 user_id: user.map(|u| u.id),
280 user_session_id: user_session.map(|s| s.id),
281 client_id: client.id,
282 scope,
283 user_agent: None,
284 last_active_at: None,
285 last_active_ip: None,
286 human_name: None,
287 })
288 }
289
290 #[tracing::instrument(
291 name = "db.oauth2_session.finish_bulk",
292 skip_all,
293 fields(
294 db.query.text,
295 ),
296 err,
297 )]
298 async fn finish_bulk(
299 &mut self,
300 clock: &dyn Clock,
301 filter: OAuth2SessionFilter<'_>,
302 ) -> Result<usize, Self::Error> {
303 let finished_at = clock.now();
304 let (sql, arguments) = Query::update()
305 .table(OAuth2Sessions::Table)
306 .value(OAuth2Sessions::FinishedAt, finished_at)
307 .apply_filter(filter)
308 .build_sqlx(PostgresQueryBuilder);
309
310 let res = sqlx::query_with(&sql, arguments)
311 .traced()
312 .execute(&mut *self.conn)
313 .await?;
314
315 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
316 }
317
318 #[tracing::instrument(
319 name = "db.oauth2_session.finish",
320 skip_all,
321 fields(
322 db.query.text,
323 %session.id,
324 %session.scope,
325 client.id = %session.client_id,
326 ),
327 err,
328 )]
329 async fn finish(
330 &mut self,
331 clock: &dyn Clock,
332 session: Session,
333 ) -> Result<Session, Self::Error> {
334 let finished_at = clock.now();
335 let res = sqlx::query!(
336 r#"
337 UPDATE oauth2_sessions
338 SET finished_at = $2
339 WHERE oauth2_session_id = $1
340 "#,
341 Uuid::from(session.id),
342 finished_at,
343 )
344 .traced()
345 .execute(&mut *self.conn)
346 .await?;
347
348 DatabaseError::ensure_affected_rows(&res, 1)?;
349
350 session
351 .finish(finished_at)
352 .map_err(DatabaseError::to_invalid_operation)
353 }
354
355 #[tracing::instrument(
356 name = "db.oauth2_session.list",
357 skip_all,
358 fields(
359 db.query.text,
360 ),
361 err,
362 )]
363 async fn list(
364 &mut self,
365 filter: OAuth2SessionFilter<'_>,
366 pagination: Pagination,
367 ) -> Result<Page<Session>, Self::Error> {
368 let (sql, arguments) = Query::select()
369 .expr_as(
370 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)),
371 OAuthSessionLookupIden::Oauth2SessionId,
372 )
373 .expr_as(
374 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId)),
375 OAuthSessionLookupIden::UserId,
376 )
377 .expr_as(
378 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId)),
379 OAuthSessionLookupIden::UserSessionId,
380 )
381 .expr_as(
382 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId)),
383 OAuthSessionLookupIden::Oauth2ClientId,
384 )
385 .expr_as(
386 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::ScopeList)),
387 OAuthSessionLookupIden::ScopeList,
388 )
389 .expr_as(
390 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::CreatedAt)),
391 OAuthSessionLookupIden::CreatedAt,
392 )
393 .expr_as(
394 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt)),
395 OAuthSessionLookupIden::FinishedAt,
396 )
397 .expr_as(
398 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserAgent)),
399 OAuthSessionLookupIden::UserAgent,
400 )
401 .expr_as(
402 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveAt)),
403 OAuthSessionLookupIden::LastActiveAt,
404 )
405 .expr_as(
406 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::LastActiveIp)),
407 OAuthSessionLookupIden::LastActiveIp,
408 )
409 .expr_as(
410 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::HumanName)),
411 OAuthSessionLookupIden::HumanName,
412 )
413 .from(OAuth2Sessions::Table)
414 .apply_filter(filter)
415 .generate_pagination(
416 (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
417 pagination,
418 )
419 .build_sqlx(PostgresQueryBuilder);
420
421 let edges: Vec<OAuthSessionLookup> = sqlx::query_as_with(&sql, arguments)
422 .traced()
423 .fetch_all(&mut *self.conn)
424 .await?;
425
426 let page = pagination.process(edges).try_map(Session::try_from)?;
427
428 Ok(page)
429 }
430
431 #[tracing::instrument(
432 name = "db.oauth2_session.count",
433 skip_all,
434 fields(
435 db.query.text,
436 ),
437 err,
438 )]
439 async fn count(&mut self, filter: OAuth2SessionFilter<'_>) -> Result<usize, Self::Error> {
440 let (sql, arguments) = Query::select()
441 .expr(Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId)).count())
442 .from(OAuth2Sessions::Table)
443 .apply_filter(filter)
444 .build_sqlx(PostgresQueryBuilder);
445
446 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
447 .traced()
448 .fetch_one(&mut *self.conn)
449 .await?;
450
451 count
452 .try_into()
453 .map_err(DatabaseError::to_invalid_operation)
454 }
455
456 #[tracing::instrument(
457 name = "db.oauth2_session.record_batch_activity",
458 skip_all,
459 fields(
460 db.query.text,
461 ),
462 err,
463 )]
464 async fn record_batch_activity(
465 &mut self,
466 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
467 ) -> Result<(), Self::Error> {
468 activities.sort_unstable();
471 let mut ids = Vec::with_capacity(activities.len());
472 let mut last_activities = Vec::with_capacity(activities.len());
473 let mut ips = Vec::with_capacity(activities.len());
474
475 for (id, last_activity, ip) in activities {
476 ids.push(Uuid::from(id));
477 last_activities.push(last_activity);
478 ips.push(ip);
479 }
480
481 let res = sqlx::query!(
482 r#"
483 UPDATE oauth2_sessions
484 SET last_active_at = GREATEST(t.last_active_at, oauth2_sessions.last_active_at)
485 , last_active_ip = COALESCE(t.last_active_ip, oauth2_sessions.last_active_ip)
486 FROM (
487 SELECT *
488 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
489 AS t(oauth2_session_id, last_active_at, last_active_ip)
490 ) AS t
491 WHERE oauth2_sessions.oauth2_session_id = t.oauth2_session_id
492 "#,
493 &ids,
494 &last_activities,
495 &ips as &[Option<IpAddr>],
496 )
497 .traced()
498 .execute(&mut *self.conn)
499 .await?;
500
501 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
502
503 Ok(())
504 }
505
506 #[tracing::instrument(
507 name = "db.oauth2_session.record_user_agent",
508 skip_all,
509 fields(
510 db.query.text,
511 %session.id,
512 %session.scope,
513 client.id = %session.client_id,
514 session.user_agent = user_agent,
515 ),
516 err,
517 )]
518 async fn record_user_agent(
519 &mut self,
520 mut session: Session,
521 user_agent: String,
522 ) -> Result<Session, Self::Error> {
523 let res = sqlx::query!(
524 r#"
525 UPDATE oauth2_sessions
526 SET user_agent = $2
527 WHERE oauth2_session_id = $1
528 "#,
529 Uuid::from(session.id),
530 &*user_agent,
531 )
532 .traced()
533 .execute(&mut *self.conn)
534 .await?;
535
536 session.user_agent = Some(user_agent);
537
538 DatabaseError::ensure_affected_rows(&res, 1)?;
539
540 Ok(session)
541 }
542
543 #[tracing::instrument(
544 name = "repository.oauth2_session.set_human_name",
545 skip(self),
546 fields(
547 client.id = %session.client_id,
548 session.human_name = ?human_name,
549 ),
550 err,
551 )]
552 async fn set_human_name(
553 &mut self,
554 mut session: Session,
555 human_name: Option<String>,
556 ) -> Result<Session, Self::Error> {
557 let res = sqlx::query!(
558 r#"
559 UPDATE oauth2_sessions
560 SET human_name = $2
561 WHERE oauth2_session_id = $1
562 "#,
563 Uuid::from(session.id),
564 human_name.as_deref(),
565 )
566 .traced()
567 .execute(&mut *self.conn)
568 .await?;
569
570 session.human_name = human_name;
571
572 DatabaseError::ensure_affected_rows(&res, 1)?;
573
574 Ok(session)
575 }
576}