From 0fa6536c7ace20b6b6f0150b2b2a6ea27c8ce743 Mon Sep 17 00:00:00 2001 From: appflowy Date: Thu, 16 Mar 2023 11:23:42 +0800 Subject: [PATCH] refactor: use i64 as user id --- Cargo.lock | 1 + Cargo.toml | 2 + crates/revdb/src/db.rs | 32 +-------- crates/snowflake/src/lib.rs | 30 +++++--- migrations/20230312043023_user.sql | 11 +-- src/api/user.rs | 20 +----- src/application.rs | 3 + src/component/auth/password.rs | 4 +- src/component/auth/user.rs | 111 ++++++++++++++--------------- src/state.rs | 9 ++- 10 files changed, 98 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 52ec66d2..6786e6ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -461,6 +461,7 @@ dependencies = [ "serde", "serde-aux", "serde_json", + "snowflake", "sqlx", "thiserror", "token", diff --git a/Cargo.toml b/Cargo.toml index 43daccac..6c38cccc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,9 @@ tracing-actix-web = "0.7" tracing-log = "0.1.1" sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-rustls", "macros", "postgres", "uuid", "chrono", "migrate", "offline"] } +#Local crate token = { path = "./crates/token" } +snowflake = { path = "./crates/snowflake" } [dev-dependencies] once_cell = "1.7.2" diff --git a/crates/revdb/src/db.rs b/crates/revdb/src/db.rs index cfbdd213..bfd4df92 100644 --- a/crates/revdb/src/db.rs +++ b/crates/revdb/src/db.rs @@ -24,6 +24,7 @@ impl RevDB { Ok(value.map(|value| value.to_vec())) } } + // Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential, // so try to organize the data in a way that maximizes sequential access. fn make_seq_key(uid: i64, rev_id: i64) -> [u8; 16] { @@ -32,34 +33,3 @@ fn make_seq_key(uid: i64, rev_id: i64) -> [u8; 16] { key[8..16].copy_from_slice(&rev_id.to_be_bytes()); key } - -#[cfg(test)] -mod tests { - use crate::db::RevDB; - use std::path::Path; - use std::time::Instant; - - #[test] - fn insert_speed() { - let path = Path::new("."); - let db = RevDB::open(path).unwrap(); - let start_time = Instant::now(); - - for i in 0..=100000 { - db.insert(1, i, b"hello world").unwrap(); - } - - for i in 0..=100000 { - db.get(1, i).unwrap(); - } - - let end_time = Instant::now(); - let elapsed_time = end_time - start_time; - // Print the elapsed time in seconds and milliseconds - println!( - "Elapsed time: {}s, {}ms", - elapsed_time.as_secs(), - elapsed_time.subsec_millis() - ); - } -} diff --git a/crates/snowflake/src/lib.rs b/crates/snowflake/src/lib.rs index 8a63d8f1..b35b209c 100644 --- a/crates/snowflake/src/lib.rs +++ b/crates/snowflake/src/lib.rs @@ -1,14 +1,13 @@ -use std::time::{Duration, SystemTime}; +use std::time::SystemTime; -const EPOCH: u64 = 1420070400000; +const EPOCH: u64 = 1637806706000; const NODE_ID_BITS: u64 = 10; const SEQUENCE_BITS: u64 = 12; const NODE_ID_SHIFT: u64 = SEQUENCE_BITS; const TIMESTAMP_SHIFT: u64 = NODE_ID_BITS + SEQUENCE_BITS; const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1; -const MAX_NODE_ID: u64 = (1 << NODE_ID_BITS) - 1; -struct Snowflake { +pub struct Snowflake { node_id: u64, sequence: u64, last_timestamp: u64, @@ -23,7 +22,7 @@ impl Snowflake { } } - pub fn next_id(&mut self) -> u64 { + pub fn next_id(&mut self) -> i64 { let timestamp = self.timestamp(); if timestamp < self.last_timestamp { panic!("Clock moved backwards!"); @@ -39,7 +38,9 @@ impl Snowflake { } self.last_timestamp = timestamp; - (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence + let id = + (timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence; + id as i64 } fn wait_next_millis(&self) { @@ -52,12 +53,21 @@ impl Snowflake { fn timestamp(&self) -> u64 { SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() + .expect("Clock moved backwards!") .as_millis() as u64 } } -fn main() { - let mut snowflake = Snowflake::new(1); - println!("{}", snowflake.next_id()); +#[cfg(test)] +mod tests { + use crate::Snowflake; + + #[test] + fn gen_id() { + let mut snow_flake = Snowflake::new(1); + let id_1 = snow_flake.next_id(); + let id_2 = snow_flake.next_id(); + + assert_ne!(id_1, id_2); + } } diff --git a/migrations/20230312043023_user.sql b/migrations/20230312043023_user.sql index f9c556e7..2d3a5299 100644 --- a/migrations/20230312043023_user.sql +++ b/migrations/20230312043023_user.sql @@ -1,8 +1,9 @@ -- Add migration script here -CREATE TABLE users ( - uid uuid PRIMARY KEY, - username TEXT NOT NULL, - password TEXT NOT NULL, - email TEXT NOT NULL UNIQUE, +CREATE TABLE users +( + uid bigint PRIMARY KEY, + username TEXT NOT NULL, + password TEXT NOT NULL, + email TEXT NOT NULL UNIQUE, create_time timestamptz NOT NULL ); \ No newline at end of file diff --git a/src/api/user.rs b/src/api/user.rs index 8cb03951..1bcddfe7 100644 --- a/src/api/user.rs +++ b/src/api/user.rs @@ -30,14 +30,7 @@ async fn login_handler( let password = UserPassword::parse(req.password) .map_err(|_| InputParamsError::InvalidPassword)? .0; - let (resp, token) = login( - state.pg_pool.clone(), - state.user.clone(), - email, - password, - &state.config.application.server_key, - ) - .await?; + let (resp, token) = login(email, password, &state).await?; // Renews the session key, assigning existing session state to new key. session.renew(); @@ -68,16 +61,7 @@ async fn register_handler(req: Json, state: Data) -> Res .map_err(|_| InputParamsError::InvalidPassword)? .0; - let resp = register( - state.pg_pool.clone(), - state.user.clone(), - name, - email, - password, - &state.config.application.server_key, - ) - .await?; - + let resp = register(name, email, password, &state).await?; Ok(HttpResponse::Ok().json(resp)) } diff --git a/src/application.rs b/src/application.rs index 8f596fec..f15906f6 100644 --- a/src/application.rs +++ b/src/application.rs @@ -13,9 +13,11 @@ use actix_web::{dev::Server, web, web::Data, App, HttpServer}; use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod}; use openssl::x509::X509; use secrecy::{ExposeSecret, Secret}; +use snowflake::Snowflake; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::net::TcpListener; use std::sync::Arc; +use tokio::sync::RwLock; use tracing_actix_web::TracingLogger; pub struct Application { @@ -96,6 +98,7 @@ pub async fn init_state(config: &Config) -> State { pg_pool, config: Arc::new(config.clone()), user: Arc::new(Default::default()), + id_gen: Arc::new(RwLock::new(Snowflake::new(1))), } } diff --git a/src/component/auth/password.rs b/src/component/auth/password.rs index 38152711..39878664 100644 --- a/src/component/auth/password.rs +++ b/src/component/auth/password.rs @@ -16,7 +16,7 @@ pub struct Credentials { pub async fn validate_credentials( credentials: Credentials, pool: &PgPool, -) -> Result { +) -> Result { let mut uid = None; let mut expected_hash_password = Secret::new( "$argon2id$v=19$m=15000,t=2,p=1$\ @@ -58,7 +58,7 @@ pub fn compute_hash_password(password: &[u8]) -> Result, anyhow:: async fn get_stored_credentials( email: &str, pool: &PgPool, -) -> Result)>, anyhow::Error> { +) -> Result)>, anyhow::Error> { let row = sqlx::query!( r#" SELECT uid, password diff --git a/src/component/auth/user.rs b/src/component/auth/user.rs index f5f635b5..f4936eb5 100644 --- a/src/component/auth/user.rs +++ b/src/component/auth/user.rs @@ -2,51 +2,46 @@ use crate::component::auth::{ compute_hash_password, internal_error, validate_credentials, AuthError, Credentials, }; use crate::config::env::domain; -use crate::state::UserCache; +use crate::state::{State, UserCache}; use crate::telemetry::spawn_blocking_with_tracing; use actix_web::HttpRequest; use anyhow::Context; use chrono::Duration; use chrono::Utc; -use secrecy::{ExposeSecret, Secret, Zeroize}; +use secrecy::zeroize::DefaultIsZeroes; +use secrecy::{CloneableSecret, DebugSecret, ExposeSecret, Secret, Zeroize}; use serde::{Deserialize, Serialize}; -use sqlx::types::uuid; use sqlx::{PgPool, Postgres, Transaction}; use std::sync::Arc; use token::{create_token, parse_token, TokenError}; use tokio::sync::RwLock; pub async fn login( - pg_pool: PgPool, - cache: Arc>, email: String, password: String, - server_key: &Secret, + state: &State, ) -> Result<(LoginResponse, Secret), AuthError> { let credentials = Credentials { email, password: Secret::new(password), }; + let server_key = &state.config.application.server_key; - match validate_credentials(credentials, &pg_pool).await { + match validate_credentials(credentials, &state.pg_pool).await { Ok(uid) => { - let uid = uid.to_string(); - let token = Token::create_token(&uid, server_key)?; - let logged_user = LoggedUser::new(uid.clone()); - cache.write().await.authorized(logged_user); + let token = Token::create_token(uid, server_key)?; + let logged_user = LoggedUser::new(uid); + state.user.write().await.authorized(logged_user); Ok(( LoginResponse { token: token.clone().into(), - uid, + uid: uid.to_string(), }, Secret::new(token), )) } - Err(err) => { - // - Err(err) - } + Err(err) => Err(err), } } @@ -55,13 +50,13 @@ pub async fn logout(logged_user: LoggedUser, cache: Arc>) { } pub async fn register( - pg_pool: PgPool, - cache: Arc>, username: String, email: String, password: String, - server_key: &Secret, + state: &State, ) -> Result { + let pg_pool = state.pg_pool.clone(); + let server_key = &state.config.application.server_key; let mut transaction = pg_pool .begin() .await @@ -75,15 +70,15 @@ pub async fn register( return Err(AuthError::UserAlreadyExist { email }); } - let uuid = uuid::Uuid::new_v4(); - let token = Token::create_token(&uuid.to_string(), server_key)?; + let uid = state.id_gen.write().await.next_id(); + let token = Token::create_token(uid, server_key)?; let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?; let _ = sqlx::query!( r#" INSERT INTO users (uid, email, username, create_time, password) VALUES ($1, $2, $3, $4, $5) "#, - uuid, + uid, email, username, Utc::now(), @@ -100,8 +95,8 @@ pub async fn register( .context("Failed to commit SQL transaction to register user.") .map_err(internal_error)?; - let logged_user = LoggedUser::new(uuid.to_string()); - cache.write().await.authorized(logged_user); + let logged_user = LoggedUser::new(uid); + state.user.write().await.authorized(logged_user); Ok(RegisterResponse { token: token.into(), @@ -120,7 +115,7 @@ pub async fn change_password( .context("Failed to acquire a Postgres connection to change password") .map_err(internal_error)?; - let email = get_user_email(logged_user.expose_secret(), &mut transaction).await?; + let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?; // check password let credentials = Credentials { @@ -135,15 +130,11 @@ pub async fn change_password( .await .context("Failed to hash password")??; - let uid = - uuid::Uuid::parse_str(logged_user.expose_secret()).map_err(|e| AuthError::InvalidUuid { - err: format!("{}", e), - })?; // Save password to disk let sql = "UPDATE users SET password = $1 where uid = $2"; let _ = sqlx::query(sql) .bind(new_hash_password.expose_secret()) - .bind(uid) + .bind(logged_user.expose_secret()) .execute(&mut transaction) .await .context("Failed to change user's password in the database.")?; @@ -157,10 +148,9 @@ pub async fn change_password( } pub async fn get_user_email( - uid: &str, + uid: i64, transaction: &mut Transaction<'_, Postgres>, ) -> Result { - let uid = uuid::Uuid::parse_str(uid)?; let row = sqlx::query!( r#" SELECT email @@ -219,37 +209,47 @@ pub struct ChangePasswordRequest { pub new_password_confirm: String, } -#[derive(Debug, Clone)] -pub struct LoggedUser { - uid: Secret, +#[derive(Clone)] +pub struct WrapI64(i64); +impl Default for WrapI64 { + fn default() -> Self { + Self(0) + } } +impl Copy for WrapI64 {} +impl DefaultIsZeroes for WrapI64 {} +impl DebugSecret for WrapI64 {} +impl CloneableSecret for WrapI64 {} + +impl std::ops::Deref for WrapI64 { + type Target = i64; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub struct LoggedUser(Secret); impl From for LoggedUser { fn from(c: Claim) -> Self { - Self { - uid: Secret::new(c.uid), - } + Self(Secret::new(WrapI64(c.uid))) } } impl LoggedUser { - pub fn new(uid: String) -> Self { - Self { - uid: Secret::new(uid), - } + pub fn new(uid: i64) -> Self { + Self(Secret::new(WrapI64(uid))) } pub fn from_token(server_key: &Secret, token: &str) -> Result { let user: LoggedUser = Token::decode_token(server_key, token)?.into(); Ok(user) } -} -impl std::ops::Deref for LoggedUser { - type Target = Secret; - - fn deref(&self) -> &Self::Target { - &self.uid + pub fn expose_secret(&self) -> &i64 { + self.0.expose_secret() } } @@ -259,15 +259,12 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30; #[derive(Debug, Serialize, Deserialize)] pub struct Claim { iss: String, - uid: String, + uid: i64, } impl Claim { - pub fn with_user_id(uid: &str) -> Self { - Self { - iss: domain(), - uid: uid.to_string(), - } + pub fn with_user_id(uid: i64) -> Self { + Self { iss: domain(), uid } } } @@ -281,8 +278,8 @@ impl Zeroize for Token { } impl Token { - pub fn create_token(user_id: &str, server_key: &Secret) -> Result { - let claim = Claim::with_user_id(user_id); + pub fn create_token(uid: i64, server_key: &Secret) -> Result { + let claim = Claim::with_user_id(uid); let token = create_token( server_key.expose_secret().as_str(), claim, @@ -317,7 +314,7 @@ pub fn logged_user_from_request( pub fn uid_from_request( request: &HttpRequest, server_key: &Secret, -) -> Result, AuthError> { +) -> Result, AuthError> { match request.headers().get(HEADER_TOKEN) { Some(header) => match header.to_str() { Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)), diff --git a/src/state.rs b/src/state.rs index 23183941..f2e620c9 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,7 +1,7 @@ use crate::component::auth::LoggedUser; use crate::config::config::Config; use chrono::{DateTime, Utc}; -use secrecy::ExposeSecret; +use snowflake::Snowflake; use sqlx::PgPool; use std::collections::BTreeMap; use std::sync::Arc; @@ -12,12 +12,17 @@ pub struct State { pub pg_pool: PgPool, pub config: Arc, pub user: Arc>, + pub id_gen: Arc>, } impl State { pub async fn load_users(_pool: &PgPool) { todo!() } + + pub async fn next_user_id(&self) -> i64 { + self.id_gen.write().await.next_id() + } } #[derive(Clone, Debug, Copy)] @@ -31,7 +36,7 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30; #[derive(Debug, Default)] pub struct UserCache { // Keep track the user authentication state - user: BTreeMap, + user: BTreeMap, } impl UserCache {