chore: check file md5 before import (#895)
This commit is contained in:
parent
7d6d1fd151
commit
3623d9f296
|
|
@ -806,6 +806,7 @@ dependencies = [
|
|||
"aws-config",
|
||||
"aws-sdk-s3",
|
||||
"axum 0.7.5",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"collab",
|
||||
"collab-database",
|
||||
|
|
@ -818,6 +819,7 @@ dependencies = [
|
|||
"futures",
|
||||
"infra",
|
||||
"mailer",
|
||||
"md5",
|
||||
"mime_guess",
|
||||
"redis 0.25.4",
|
||||
"secrecy",
|
||||
|
|
|
|||
|
|
@ -46,5 +46,7 @@ mime_guess = "2.0"
|
|||
bytes.workspace = true
|
||||
uuid.workspace = true
|
||||
mailer.workspace = true
|
||||
md5.workspace = true
|
||||
base64.workspace = true
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -397,7 +397,14 @@ async fn download_and_unzip_file(
|
|||
.map_err(|err| ImportError::Internal(err.into()))?;
|
||||
let buffer_size = buffer_size_from_content_length(content_length);
|
||||
|
||||
let zip_reader = get_zip_reader(storage_dir, stream, buffer_size, streaming).await?;
|
||||
let zip_reader = get_zip_reader(
|
||||
storage_dir,
|
||||
stream,
|
||||
buffer_size,
|
||||
streaming,
|
||||
&import_task.md5_base64,
|
||||
)
|
||||
.await?;
|
||||
let unique_file_name = Uuid::new_v4().to_string();
|
||||
let output_file_path = storage_dir.join(unique_file_name);
|
||||
fs::create_dir_all(&output_file_path)
|
||||
|
|
@ -431,6 +438,7 @@ async fn get_zip_reader(
|
|||
stream: Box<dyn AsyncBufRead + Unpin + Send>,
|
||||
buffer_size: usize,
|
||||
streaming: bool,
|
||||
file_md5_base64: &Option<String>,
|
||||
) -> Result<ZipReader, ImportError> {
|
||||
let zip_reader = if streaming {
|
||||
// Occasionally, we encounter the error 'unable to locate the end of central directory record'
|
||||
|
|
@ -444,7 +452,7 @@ async fn get_zip_reader(
|
|||
file: None,
|
||||
}
|
||||
} else {
|
||||
let file = download_file(storage_dir, stream).await?;
|
||||
let file = download_file(storage_dir, stream, file_md5_base64).await?;
|
||||
let handle = fs::File::open(&file)
|
||||
.await
|
||||
.map_err(|err| ImportError::Internal(err.into()))?;
|
||||
|
|
@ -996,6 +1004,8 @@ pub struct NotionImportTask {
|
|||
pub workspace_name: String,
|
||||
pub s3_key: String,
|
||||
pub host: String,
|
||||
#[serde(default)]
|
||||
pub md5_base64: Option<String>,
|
||||
}
|
||||
impl Display for NotionImportTask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ use anyhow::Result;
|
|||
use aws_sdk_s3::operation::get_object::GetObjectError;
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use axum::async_trait;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use futures::AsyncReadExt;
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
|
@ -170,16 +172,19 @@ impl Drop for AutoRemoveDownloadedFile {
|
|||
pub async fn download_file(
|
||||
storage_dir: &Path,
|
||||
stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
|
||||
expected_md5_base64: &Option<String>,
|
||||
) -> Result<AutoRemoveDownloadedFile, anyhow::Error> {
|
||||
let zip_file_path = storage_dir.join(format!("{}.zip", Uuid::new_v4()));
|
||||
write_stream_to_file(&zip_file_path, stream).await?;
|
||||
write_stream_to_file(&zip_file_path, expected_md5_base64, stream).await?;
|
||||
Ok(AutoRemoveDownloadedFile(zip_file_path))
|
||||
}
|
||||
|
||||
pub async fn write_stream_to_file(
|
||||
file_path: &PathBuf,
|
||||
expected_md5_base64: &Option<String>,
|
||||
mut stream: Box<dyn futures::AsyncBufRead + Unpin + Send>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let mut context = md5::Context::new();
|
||||
let mut file = File::create(file_path).await?;
|
||||
let mut buffer = vec![0u8; 1_048_576];
|
||||
loop {
|
||||
|
|
@ -187,9 +192,22 @@ pub async fn write_stream_to_file(
|
|||
if bytes_read == 0 {
|
||||
break;
|
||||
}
|
||||
context.consume(&buffer[..bytes_read]);
|
||||
file.write_all(&buffer[..bytes_read]).await?;
|
||||
}
|
||||
file.flush().await?;
|
||||
|
||||
let digest = context.compute();
|
||||
let md5_base64 = STANDARD.encode(digest.as_ref());
|
||||
if let Some(expected_md5) = expected_md5_base64 {
|
||||
if md5_base64 != *expected_md5 {
|
||||
error!(
|
||||
"[Import]: MD5 mismatch, expected: {}, current: {}",
|
||||
expected_md5, md5_base64
|
||||
);
|
||||
return Err(anyhow!("MD5 mismatch"));
|
||||
}
|
||||
}
|
||||
|
||||
file.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ use base64::Engine;
|
|||
use database::user::select_name_and_email_from_uuid;
|
||||
use database::workspace::select_import_task;
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::json;
|
||||
use shared_entity::dto::import_dto::{ImportTaskDetail, ImportTaskStatus, UserImportTask};
|
||||
use shared_entity::response::{AppResponse, JsonAppResponse};
|
||||
use std::env::temp_dir;
|
||||
|
|
@ -76,7 +77,7 @@ async fn import_data_handler(
|
|||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let md5 = req
|
||||
let md5_base64 = req
|
||||
.headers()
|
||||
.get("X-Content-MD5")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
|
|
@ -88,19 +89,19 @@ async fn import_data_handler(
|
|||
trace!(
|
||||
"[Import] content length: {}, content md5: {}",
|
||||
content_length,
|
||||
md5
|
||||
md5_base64
|
||||
);
|
||||
if file.md5_base64 != md5 {
|
||||
if file.md5_base64 != md5_base64 {
|
||||
trace!(
|
||||
"Import file fail. The Content-MD5:{} doesn't match file md5:{}",
|
||||
md5,
|
||||
md5_base64,
|
||||
file.md5_base64
|
||||
);
|
||||
|
||||
return Err(
|
||||
AppError::InvalidRequest(format!(
|
||||
"Content-MD5:{} doesn't match file md5:{}",
|
||||
md5, file.md5_base64
|
||||
md5_base64, file.md5_base64
|
||||
))
|
||||
.into(),
|
||||
);
|
||||
|
|
@ -145,14 +146,29 @@ async fn import_data_handler(
|
|||
.put_blob_as_content_type(&workspace_id, stream, "application/zip")
|
||||
.await?;
|
||||
|
||||
// This task will be deserialized into ImportTask
|
||||
let task_id = Uuid::new_v4();
|
||||
let task = json!({
|
||||
"notion": {
|
||||
"uid": uid,
|
||||
"user_name": user_name,
|
||||
"user_email": user_email,
|
||||
"task_id": task_id.to_string(),
|
||||
"workspace_id": workspace_id,
|
||||
"s3_key": workspace_id,
|
||||
"host": host,
|
||||
"workspace_name": &file.name,
|
||||
"md5_base64": md5_base64,
|
||||
}
|
||||
});
|
||||
|
||||
create_upload_task(
|
||||
uid,
|
||||
&user_name,
|
||||
&user_email,
|
||||
&workspace_id,
|
||||
&file.name,
|
||||
file.size,
|
||||
task_id,
|
||||
task,
|
||||
&host,
|
||||
&workspace_id,
|
||||
file.size,
|
||||
&state.redis_connection_manager,
|
||||
&state.pg_pool,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -678,17 +678,14 @@ async fn check_if_user_is_allowed_to_delete_comment(
|
|||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn create_upload_task(
|
||||
uid: i64,
|
||||
user_name: &str,
|
||||
user_email: &str,
|
||||
workspace_id: &str,
|
||||
workspace_name: &str,
|
||||
file_size: usize,
|
||||
task_id: Uuid,
|
||||
task: serde_json::Value,
|
||||
host: &str,
|
||||
workspace_id: &str,
|
||||
file_size: usize,
|
||||
redis_client: &RedisConnectionManager,
|
||||
pg_pool: &PgPool,
|
||||
) -> Result<(), AppError> {
|
||||
let task_id = Uuid::new_v4();
|
||||
|
||||
// Insert the task into the database
|
||||
insert_import_task(
|
||||
task_id,
|
||||
|
|
@ -700,19 +697,6 @@ pub async fn create_upload_task(
|
|||
)
|
||||
.await?;
|
||||
|
||||
// This task will be deserialized into ImportTask
|
||||
let task = json!({
|
||||
"notion": {
|
||||
"uid": uid,
|
||||
"user_name": user_name,
|
||||
"user_email": user_email,
|
||||
"task_id": task_id,
|
||||
"workspace_id": workspace_id,
|
||||
"s3_key": workspace_id,
|
||||
"host": host,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
});
|
||||
let _: () = redis_client
|
||||
.clone()
|
||||
.xadd("import_task_stream", "*", &[("task", task.to_string())])
|
||||
|
|
|
|||
Loading…
Reference in New Issue