chore: set max import zip file size (#1011)

* chore: set max import zip file size

* chore: fix test
This commit is contained in:
Nathan.fooo 2024-11-20 14:07:36 +08:00 committed by GitHub
parent afeaeb7796
commit 1e18180e9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 76 additions and 3 deletions

View File

@ -107,6 +107,12 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
.parse::<u64>()
.unwrap_or(10);
// Maximum file size for import
let maximum_import_file_size =
get_env_var("APPFLOWY_WORKER_MAX_IMPORT_FILE_SIZE", "1_000_000_000")
.parse::<u64>()
.unwrap_or(1_000_000_000);
let import_worker_fut = local_set.run_until(run_import_worker(
state.pg_pool.clone(),
state.redis_client.clone(),
@ -115,6 +121,7 @@ pub async fn create_app(listener: TcpListener, config: Config) -> Result<(), Err
Arc::new(email_notifier),
"import_task_stream",
tick_interval,
maximum_import_file_size,
));
let app = Router::new()

View File

@ -40,6 +40,15 @@ pub enum ImportError {
#[error("Upload file expired")]
UploadFileExpire,
#[error("Please upgrade to the latest version of the app")]
UpgradeToLatestVersion(String),
#[error("Upload file too large")]
UploadFileTooLarge {
file_size_in_mb: f64,
max_size_in_mb: f64,
},
#[error(transparent)]
Internal(#[from] anyhow::Error),
}
@ -184,6 +193,27 @@ impl ImportError {
format!("Task ID: {} - Upload file expired", task_id),
)
}
ImportError::UpgradeToLatestVersion(s) => {
(
format!(
"Task ID: {} - {}, please upgrade to the latest version of the app to import this file",
task_id,
s,
),
format!("Task ID: {} - Upgrade to latest version", task_id),
)
}
ImportError::UploadFileTooLarge{ file_size_in_mb, max_size_in_mb}=> {
(
format!(
"Task ID: {} - The file size is too large. The maximum file size allowed is {} MB. Please upload a smaller file.",
task_id,
max_size_in_mb,
),
format!("Task ID: {} - Upload file too large: {} MB", task_id, file_size_in_mb),
)
}
}
}
}

View File

@ -78,6 +78,7 @@ pub async fn run_import_worker(
notifier: Arc<dyn ImportNotifier>,
stream_name: &str,
tick_interval_secs: u64,
max_import_file_size: u64,
) -> Result<(), ImportError> {
info!("Starting importer worker");
if let Err(err) = ensure_consumer_group(stream_name, GROUP_NAME, &mut redis_client).await {
@ -95,6 +96,7 @@ pub async fn run_import_worker(
CONSUMER_NAME,
notifier.clone(),
&metrics,
max_import_file_size,
)
.await;
@ -109,6 +111,7 @@ pub async fn run_import_worker(
notifier.clone(),
tick_interval_secs,
&metrics,
max_import_file_size,
)
.await?;
@ -126,6 +129,7 @@ async fn process_un_acked_tasks(
consumer_name: &str,
notifier: Arc<dyn ImportNotifier>,
metrics: &Option<Arc<ImportMetrics>>,
maximum_import_file_size: u64,
) {
// when server restarts, we need to check if there are any unacknowledged tasks
match get_un_ack_tasks(stream_name, group_name, consumer_name, redis_client).await {
@ -139,6 +143,7 @@ async fn process_un_acked_tasks(
pg_pool: pg_pool.clone(),
notifier: notifier.clone(),
metrics: metrics.clone(),
maximum_import_file_size,
};
// Ignore the error here since the consume task will handle the error
let _ = consume_task(
@ -167,6 +172,7 @@ async fn process_upcoming_tasks(
notifier: Arc<dyn ImportNotifier>,
interval_secs: u64,
metrics: &Option<Arc<ImportMetrics>>,
maximum_import_file_size: u64,
) -> Result<(), ImportError> {
let options = StreamReadOptions::default()
.group(group_name, consumer_name)
@ -215,6 +221,7 @@ async fn process_upcoming_tasks(
pg_pool: pg_pool.clone(),
notifier: notifier.clone(),
metrics: metrics.clone(),
maximum_import_file_size,
};
let handle = spawn_local(async move {
@ -254,6 +261,7 @@ struct TaskContext {
pg_pool: PgPool,
notifier: Arc<dyn ImportNotifier>,
metrics: Option<Arc<ImportMetrics>>,
maximum_import_file_size: u64,
}
#[allow(clippy::too_many_arguments)]
@ -270,6 +278,26 @@ async fn consume_task(
return process_and_ack_task(context, import_task, stream_name, group_name, &entry_id).await;
}
match task.file_size {
None => {
return Err(ImportError::UpgradeToLatestVersion(format!(
"Missing file_size for task: {}",
task.task_id
)))
},
Some(file_size) => {
if file_size > context.maximum_import_file_size as i64 {
let file_size_in_mb = file_size as f64 / 1_048_576.0;
let max_size_in_mb = context.maximum_import_file_size as f64 / 1_048_576.0;
return Err(ImportError::UploadFileTooLarge {
file_size_in_mb,
max_size_in_mb,
});
}
},
}
// Check if the task is expired
if let Err(err) = is_task_expired(task.created_at.unwrap(), task.last_process_at) {
if let Ok(import_record) = select_import_task(&context.pg_pool, &task.task_id).await {
@ -1395,10 +1423,11 @@ pub struct NotionImportTask {
impl Display for NotionImportTask {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let file_size_mb = self.file_size.map(|size| size as f64 / 1_048_576.0);
write!(
f,
"NotionImportTask {{ task_id: {}, workspace_id: {}, file_size:{:?}, workspace_name: {}, user_name: {}, user_email: {} }}",
self.task_id, self.workspace_id, self.file_size, self.workspace_name, self.user_name, self.user_email
"NotionImportTask {{ task_id: {}, workspace_id: {}, file_size:{:?}MB, workspace_name: {}, user_name: {}, user_email: {} }}",
self.task_id, self.workspace_id, file_size_mb, self.workspace_name, self.user_name, self.user_email
)
}
}

View File

@ -136,6 +136,7 @@ fn run_importer_worker(
tick_interval_secs: u64,
) -> std::thread::JoinHandle<()> {
setup_log();
let max_import_file_size = 1_000_000_000;
std::thread::spawn(move || {
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
@ -148,6 +149,7 @@ fn run_importer_worker(
notifier,
&stream_name,
tick_interval_secs,
max_import_file_size,
));
runtime.block_on(import_worker_fut).unwrap();
})

View File

@ -131,7 +131,12 @@ The Five Dysfunctions of a Team by Patrick Lencioni The Five Dysfunctions of a T
.unwrap()
.score;
assert!(score > 0.9, "score: {}, input:{}", score, answer);
assert!(
score > 0.8,
"expected: 0.8, but got score: {}, input:{}",
score,
answer
);
}
}