refactor: use i64 as user id

This commit is contained in:
appflowy 2023-03-16 11:23:42 +08:00
parent 277b54711c
commit 0fa6536c7a
10 changed files with 98 additions and 125 deletions

1
Cargo.lock generated
View File

@ -461,6 +461,7 @@ dependencies = [
"serde",
"serde-aux",
"serde_json",
"snowflake",
"sqlx",
"thiserror",
"token",

View File

@ -60,7 +60,9 @@ tracing-actix-web = "0.7"
tracing-log = "0.1.1"
sqlx = { version = "0.6", default-features = false, features = ["runtime-actix-rustls", "macros", "postgres", "uuid", "chrono", "migrate", "offline"] }
#Local crate
token = { path = "./crates/token" }
snowflake = { path = "./crates/snowflake" }
[dev-dependencies]
once_cell = "1.7.2"

View File

@ -24,6 +24,7 @@ impl RevDB {
Ok(value.map(|value| value.to_vec()))
}
}
// Optimize your data layout: Sled's B-Tree implementation works best when the keys are sequential,
// so try to organize the data in a way that maximizes sequential access.
fn make_seq_key(uid: i64, rev_id: i64) -> [u8; 16] {
@ -32,34 +33,3 @@ fn make_seq_key(uid: i64, rev_id: i64) -> [u8; 16] {
key[8..16].copy_from_slice(&rev_id.to_be_bytes());
key
}
#[cfg(test)]
mod tests {
use crate::db::RevDB;
use std::path::Path;
use std::time::Instant;
#[test]
fn insert_speed() {
let path = Path::new(".");
let db = RevDB::open(path).unwrap();
let start_time = Instant::now();
for i in 0..=100000 {
db.insert(1, i, b"hello world").unwrap();
}
for i in 0..=100000 {
db.get(1, i).unwrap();
}
let end_time = Instant::now();
let elapsed_time = end_time - start_time;
// Print the elapsed time in seconds and milliseconds
println!(
"Elapsed time: {}s, {}ms",
elapsed_time.as_secs(),
elapsed_time.subsec_millis()
);
}
}

View File

@ -1,14 +1,13 @@
use std::time::{Duration, SystemTime};
use std::time::SystemTime;
const EPOCH: u64 = 1420070400000;
const EPOCH: u64 = 1637806706000;
const NODE_ID_BITS: u64 = 10;
const SEQUENCE_BITS: u64 = 12;
const NODE_ID_SHIFT: u64 = SEQUENCE_BITS;
const TIMESTAMP_SHIFT: u64 = NODE_ID_BITS + SEQUENCE_BITS;
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
const MAX_NODE_ID: u64 = (1 << NODE_ID_BITS) - 1;
struct Snowflake {
pub struct Snowflake {
node_id: u64,
sequence: u64,
last_timestamp: u64,
@ -23,7 +22,7 @@ impl Snowflake {
}
}
pub fn next_id(&mut self) -> u64 {
pub fn next_id(&mut self) -> i64 {
let timestamp = self.timestamp();
if timestamp < self.last_timestamp {
panic!("Clock moved backwards!");
@ -39,7 +38,9 @@ impl Snowflake {
}
self.last_timestamp = timestamp;
(timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence
let id =
(timestamp - EPOCH) << TIMESTAMP_SHIFT | self.node_id << NODE_ID_SHIFT | self.sequence;
id as i64
}
fn wait_next_millis(&self) {
@ -52,12 +53,21 @@ impl Snowflake {
fn timestamp(&self) -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.expect("Clock moved backwards!")
.as_millis() as u64
}
}
fn main() {
let mut snowflake = Snowflake::new(1);
println!("{}", snowflake.next_id());
#[cfg(test)]
mod tests {
use crate::Snowflake;
#[test]
fn gen_id() {
let mut snow_flake = Snowflake::new(1);
let id_1 = snow_flake.next_id();
let id_2 = snow_flake.next_id();
assert_ne!(id_1, id_2);
}
}

View File

@ -1,8 +1,9 @@
-- Add migration script here
CREATE TABLE users (
uid uuid PRIMARY KEY,
username TEXT NOT NULL,
password TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
CREATE TABLE users
(
uid bigint PRIMARY KEY,
username TEXT NOT NULL,
password TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
create_time timestamptz NOT NULL
);

View File

@ -30,14 +30,7 @@ async fn login_handler(
let password = UserPassword::parse(req.password)
.map_err(|_| InputParamsError::InvalidPassword)?
.0;
let (resp, token) = login(
state.pg_pool.clone(),
state.user.clone(),
email,
password,
&state.config.application.server_key,
)
.await?;
let (resp, token) = login(email, password, &state).await?;
// Renews the session key, assigning existing session state to new key.
session.renew();
@ -68,16 +61,7 @@ async fn register_handler(req: Json<RegisterRequest>, state: Data<State>) -> Res
.map_err(|_| InputParamsError::InvalidPassword)?
.0;
let resp = register(
state.pg_pool.clone(),
state.user.clone(),
name,
email,
password,
&state.config.application.server_key,
)
.await?;
let resp = register(name, email, password, &state).await?;
Ok(HttpResponse::Ok().json(resp))
}

View File

@ -13,9 +13,11 @@ use actix_web::{dev::Server, web, web::Data, App, HttpServer};
use openssl::ssl::{SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod};
use openssl::x509::X509;
use secrecy::{ExposeSecret, Secret};
use snowflake::Snowflake;
use sqlx::{postgres::PgPoolOptions, PgPool};
use std::net::TcpListener;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing_actix_web::TracingLogger;
pub struct Application {
@ -96,6 +98,7 @@ pub async fn init_state(config: &Config) -> State {
pg_pool,
config: Arc::new(config.clone()),
user: Arc::new(Default::default()),
id_gen: Arc::new(RwLock::new(Snowflake::new(1))),
}
}

View File

@ -16,7 +16,7 @@ pub struct Credentials {
pub async fn validate_credentials(
credentials: Credentials,
pool: &PgPool,
) -> Result<uuid::Uuid, AuthError> {
) -> Result<i64, AuthError> {
let mut uid = None;
let mut expected_hash_password = Secret::new(
"$argon2id$v=19$m=15000,t=2,p=1$\
@ -58,7 +58,7 @@ pub fn compute_hash_password(password: &[u8]) -> Result<Secret<String>, anyhow::
async fn get_stored_credentials(
email: &str,
pool: &PgPool,
) -> Result<Option<(uuid::Uuid, Secret<String>)>, anyhow::Error> {
) -> Result<Option<(i64, Secret<String>)>, anyhow::Error> {
let row = sqlx::query!(
r#"
SELECT uid, password

View File

@ -2,51 +2,46 @@ use crate::component::auth::{
compute_hash_password, internal_error, validate_credentials, AuthError, Credentials,
};
use crate::config::env::domain;
use crate::state::UserCache;
use crate::state::{State, UserCache};
use crate::telemetry::spawn_blocking_with_tracing;
use actix_web::HttpRequest;
use anyhow::Context;
use chrono::Duration;
use chrono::Utc;
use secrecy::{ExposeSecret, Secret, Zeroize};
use secrecy::zeroize::DefaultIsZeroes;
use secrecy::{CloneableSecret, DebugSecret, ExposeSecret, Secret, Zeroize};
use serde::{Deserialize, Serialize};
use sqlx::types::uuid;
use sqlx::{PgPool, Postgres, Transaction};
use std::sync::Arc;
use token::{create_token, parse_token, TokenError};
use tokio::sync::RwLock;
pub async fn login(
pg_pool: PgPool,
cache: Arc<RwLock<UserCache>>,
email: String,
password: String,
server_key: &Secret<String>,
state: &State,
) -> Result<(LoginResponse, Secret<Token>), AuthError> {
let credentials = Credentials {
email,
password: Secret::new(password),
};
let server_key = &state.config.application.server_key;
match validate_credentials(credentials, &pg_pool).await {
match validate_credentials(credentials, &state.pg_pool).await {
Ok(uid) => {
let uid = uid.to_string();
let token = Token::create_token(&uid, server_key)?;
let logged_user = LoggedUser::new(uid.clone());
cache.write().await.authorized(logged_user);
let token = Token::create_token(uid, server_key)?;
let logged_user = LoggedUser::new(uid);
state.user.write().await.authorized(logged_user);
Ok((
LoginResponse {
token: token.clone().into(),
uid,
uid: uid.to_string(),
},
Secret::new(token),
))
}
Err(err) => {
//
Err(err)
}
Err(err) => Err(err),
}
}
@ -55,13 +50,13 @@ pub async fn logout(logged_user: LoggedUser, cache: Arc<RwLock<UserCache>>) {
}
pub async fn register(
pg_pool: PgPool,
cache: Arc<RwLock<UserCache>>,
username: String,
email: String,
password: String,
server_key: &Secret<String>,
state: &State,
) -> Result<RegisterResponse, AuthError> {
let pg_pool = state.pg_pool.clone();
let server_key = &state.config.application.server_key;
let mut transaction = pg_pool
.begin()
.await
@ -75,15 +70,15 @@ pub async fn register(
return Err(AuthError::UserAlreadyExist { email });
}
let uuid = uuid::Uuid::new_v4();
let token = Token::create_token(&uuid.to_string(), server_key)?;
let uid = state.id_gen.write().await.next_id();
let token = Token::create_token(uid, server_key)?;
let password = compute_hash_password(password.as_bytes()).map_err(internal_error)?;
let _ = sqlx::query!(
r#"
INSERT INTO users (uid, email, username, create_time, password)
VALUES ($1, $2, $3, $4, $5)
"#,
uuid,
uid,
email,
username,
Utc::now(),
@ -100,8 +95,8 @@ pub async fn register(
.context("Failed to commit SQL transaction to register user.")
.map_err(internal_error)?;
let logged_user = LoggedUser::new(uuid.to_string());
cache.write().await.authorized(logged_user);
let logged_user = LoggedUser::new(uid);
state.user.write().await.authorized(logged_user);
Ok(RegisterResponse {
token: token.into(),
@ -120,7 +115,7 @@ pub async fn change_password(
.context("Failed to acquire a Postgres connection to change password")
.map_err(internal_error)?;
let email = get_user_email(logged_user.expose_secret(), &mut transaction).await?;
let email = get_user_email(*logged_user.expose_secret(), &mut transaction).await?;
// check password
let credentials = Credentials {
@ -135,15 +130,11 @@ pub async fn change_password(
.await
.context("Failed to hash password")??;
let uid =
uuid::Uuid::parse_str(logged_user.expose_secret()).map_err(|e| AuthError::InvalidUuid {
err: format!("{}", e),
})?;
// Save password to disk
let sql = "UPDATE users SET password = $1 where uid = $2";
let _ = sqlx::query(sql)
.bind(new_hash_password.expose_secret())
.bind(uid)
.bind(logged_user.expose_secret())
.execute(&mut transaction)
.await
.context("Failed to change user's password in the database.")?;
@ -157,10 +148,9 @@ pub async fn change_password(
}
pub async fn get_user_email(
uid: &str,
uid: i64,
transaction: &mut Transaction<'_, Postgres>,
) -> Result<String, anyhow::Error> {
let uid = uuid::Uuid::parse_str(uid)?;
let row = sqlx::query!(
r#"
SELECT email
@ -219,37 +209,47 @@ pub struct ChangePasswordRequest {
pub new_password_confirm: String,
}
#[derive(Debug, Clone)]
pub struct LoggedUser {
uid: Secret<String>,
#[derive(Clone)]
pub struct WrapI64(i64);
impl Default for WrapI64 {
fn default() -> Self {
Self(0)
}
}
impl Copy for WrapI64 {}
impl DefaultIsZeroes for WrapI64 {}
impl DebugSecret for WrapI64 {}
impl CloneableSecret for WrapI64 {}
impl std::ops::Deref for WrapI64 {
type Target = i64;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[derive(Debug, Clone)]
pub struct LoggedUser(Secret<WrapI64>);
impl From<Claim> for LoggedUser {
fn from(c: Claim) -> Self {
Self {
uid: Secret::new(c.uid),
}
Self(Secret::new(WrapI64(c.uid)))
}
}
impl LoggedUser {
pub fn new(uid: String) -> Self {
Self {
uid: Secret::new(uid),
}
pub fn new(uid: i64) -> Self {
Self(Secret::new(WrapI64(uid)))
}
pub fn from_token(server_key: &Secret<String>, token: &str) -> Result<Self, AuthError> {
let user: LoggedUser = Token::decode_token(server_key, token)?.into();
Ok(user)
}
}
impl std::ops::Deref for LoggedUser {
type Target = Secret<String>;
fn deref(&self) -> &Self::Target {
&self.uid
pub fn expose_secret(&self) -> &i64 {
self.0.expose_secret()
}
}
@ -259,15 +259,12 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claim {
iss: String,
uid: String,
uid: i64,
}
impl Claim {
pub fn with_user_id(uid: &str) -> Self {
Self {
iss: domain(),
uid: uid.to_string(),
}
pub fn with_user_id(uid: i64) -> Self {
Self { iss: domain(), uid }
}
}
@ -281,8 +278,8 @@ impl Zeroize for Token {
}
impl Token {
pub fn create_token(user_id: &str, server_key: &Secret<String>) -> Result<Self, AuthError> {
let claim = Claim::with_user_id(user_id);
pub fn create_token(uid: i64, server_key: &Secret<String>) -> Result<Self, AuthError> {
let claim = Claim::with_user_id(uid);
let token = create_token(
server_key.expose_secret().as_str(),
claim,
@ -317,7 +314,7 @@ pub fn logged_user_from_request(
pub fn uid_from_request(
request: &HttpRequest,
server_key: &Secret<String>,
) -> Result<Secret<String>, AuthError> {
) -> Result<Secret<i64>, AuthError> {
match request.headers().get(HEADER_TOKEN) {
Some(header) => match header.to_str() {
Ok(val) => Token::decode_token(server_key, val).map(|claim| Secret::new(claim.uid)),

View File

@ -1,7 +1,7 @@
use crate::component::auth::LoggedUser;
use crate::config::config::Config;
use chrono::{DateTime, Utc};
use secrecy::ExposeSecret;
use snowflake::Snowflake;
use sqlx::PgPool;
use std::collections::BTreeMap;
use std::sync::Arc;
@ -12,12 +12,17 @@ pub struct State {
pub pg_pool: PgPool,
pub config: Arc<Config>,
pub user: Arc<RwLock<UserCache>>,
pub id_gen: Arc<RwLock<Snowflake>>,
}
impl State {
pub async fn load_users(_pool: &PgPool) {
todo!()
}
pub async fn next_user_id(&self) -> i64 {
self.id_gen.write().await.next_id()
}
}
#[derive(Clone, Debug, Copy)]
@ -31,7 +36,7 @@ pub const EXPIRED_DURATION_DAYS: i64 = 30;
#[derive(Debug, Default)]
pub struct UserCache {
// Keep track the user authentication state
user: BTreeMap<String, AuthStatus>,
user: BTreeMap<i64, AuthStatus>,
}
impl UserCache {