From 53de2a0b1c708ed2bd043c05b3849c3c6ad4e328 Mon Sep 17 00:00:00 2001 From: OverflowCat Date: Tue, 9 Jul 2024 18:38:48 +0800 Subject: [PATCH] --no-edit --- Cargo.lock | 51 ++++++++++++++++++++++++++ Cargo.toml | 1 + src/archivist.rs | 52 ++++++++++++++++++++++----- src/background.rs | 38 ++++++++++++++++++++ src/item.rs | 29 ++++++++++++--- src/lib.rs | 92 ++++++++++++++++++++--------------------------- src/project.rs | 19 ++++++---- src/task.rs | 54 ++++++++++++++-------------- 8 files changed, 236 insertions(+), 100 deletions(-) create mode 100644 src/background.rs diff --git a/Cargo.lock b/Cargo.lock index caceac2..7cfaf1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -192,6 +192,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -199,6 +214,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -207,6 +223,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -225,10 +269,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -954,6 +1004,7 @@ name = "stwp" version = "0.1.0" dependencies = [ "chrono", + "futures", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 9ac9c71..bb708f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,4 @@ serde = { version = "1", features = [ ] } serde_json = "1.0" chrono = "0.4.38" +futures = "0.3.30" diff --git a/src/archivist.rs b/src/archivist.rs index 478b235..86aaaf7 100644 --- a/src/archivist.rs +++ b/src/archivist.rs @@ -1,7 +1,7 @@ +use std::env; use std::fs::File; use std::io::{self, BufRead, Write}; use std::path::Path; -use std::env; const CONFIG_FILE: &str = "ARCHIVIST.conf"; @@ -21,13 +21,15 @@ pub fn get_archivist() -> String { } fn new_archivist() -> String { - println!("zh: 初次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。(合法字符:字母、数字、-、_)"); - println!("en: This is your first time running this program. Please enter a string that uniquely identifies your node, e.g. alice-aws-114. (Legal characters: letters, numbers, -, _)"); + println!("zh: 初次运行,请输入可以唯一标识您节点的字符串,例如 neko-stwp-114。(合法字符:字母、数字、-、_)"); + println!("en: This is your first time running this program. Please enter a string that uniquely identifies your node, e.g. neko-stwp-114. (Legal characters: letters, numbers, -, _)"); print!("ARCHIVIST: "); io::stdout().flush().unwrap(); let mut arch = String::new(); - io::stdin().read_line(&mut arch).expect("Failed to read input"); + io::stdin() + .read_line(&mut arch) + .expect("Failed to read input"); let arch = arch.trim().to_string(); let mut file = File::create(CONFIG_FILE).expect("Failed to create file"); @@ -42,15 +44,47 @@ fn read_archivist() -> Option { } if Path::new(CONFIG_FILE).exists() { - let file = File::open(CONFIG_FILE).expect("Failed to open file"); - let reader = io::BufReader::new(file); - for line in reader.lines() { - return line.ok(); + let file = File::open(CONFIG_FILE); + if let Ok(file) = file { + let reader = io::BufReader::new(file); + for line in reader.lines() { + return line.ok(); + } } } None } fn is_safe_string(s: &str) -> bool { - s.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_') + s.chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_safe_string() { + assert!(is_safe_string("neko-stwp-114")); + assert!(is_safe_string("neko_stwp_114")); + assert!(!is_safe_string("牢鼠")); + assert!(!is_safe_string("laoshu stwp 514")); + assert!(!is_safe_string("neko@stwp-114")); + } + + #[test] + fn test_create_config_file() { + let arch = "neko-stwp-114"; + if Path::new(CONFIG_FILE).exists() { + std::fs::rename(CONFIG_FILE, "ARCHIVIST.conf.bak").expect("Failed to rename file"); + } + let mut file = File::create(CONFIG_FILE).expect("Failed to create file"); + writeln!(file, "{}", arch).expect("Failed to write to file"); + + let arch = read_archivist().expect("Failed to get archivist"); + assert_eq!(arch, "neko-stwp-114"); + std::fs::remove_file(CONFIG_FILE).expect("Failed to remove file"); + std::fs::rename("ARCHIVIST.conf.bak", CONFIG_FILE).expect("Failed to rename file"); + } } diff --git a/src/background.rs b/src/background.rs new file mode 100644 index 0000000..61c94cd --- /dev/null +++ b/src/background.rs @@ -0,0 +1,38 @@ +use std::sync::Arc; + +use tokio::{sync::RwLock, time::Duration}; + +use crate::{Tracker, TRACKER_NODES}; + +impl Tracker { + pub fn start_select_tracker_background(api_base: Arc>) { + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(60)).await; + let mut write_guard: tokio::sync::RwLockWriteGuard<&str> = api_base.write().await; + use futures::future::JoinAll; + println!("Selecting best tracker..."); + let durations = TRACKER_NODES + .iter() + .map(|&node| Self::get_ping(node)) + .collect::>() + .await; + let best_node = durations + .iter() + .enumerate() + .min_by_key(|(_, &elapsed)| elapsed) + .map(|(idx, _)| TRACKER_NODES[idx]) + .unwrap(); + *write_guard = best_node; + } + }); + } + + async fn get_ping(node: &str) -> Duration { + let url = format!("{}/ping", node); + let time = std::time::Instant::now(); + let resp = reqwest::get(&url).await.unwrap().text().await.unwrap(); + println!("ping {} got {}, elapsed: {:?}", node, resp, time.elapsed()); + time.elapsed() + } +} diff --git a/src/item.rs b/src/item.rs index 38c8632..ac26600 100644 --- a/src/item.rs +++ b/src/item.rs @@ -2,11 +2,12 @@ use crate::task::Id; use serde::Serialize; #[derive(Debug, Serialize, Clone)] -#[serde(rename_all_fields(serialize = "lowercase"))] +#[serde(rename_all(serialize = "lowercase"))] pub enum ItemIdType { Int, Str, } + impl From<&Id> for ItemIdType { fn from(s: &Id) -> Self { match s { @@ -16,10 +17,7 @@ impl From<&Id> for ItemIdType { } } - - #[derive(Debug, Serialize, Clone)] -// #[serde(rename_all_fields(serialize = "lowercase"))] pub enum ItemStatusType { None, #[serde(rename = "int")] @@ -36,3 +34,26 @@ pub struct Item { pub item_status_type: ItemStatusType, pub payload: String, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_item_id_json() { + let item_id = ItemIdType::Int; + let json = serde_json::to_string(&item_id).unwrap(); + assert_eq!(json, r#""int""#); + } + + #[test] + fn test_item_status_json() { + let item_status = ItemStatusType::Int; + let json = serde_json::to_string(&item_status).unwrap(); + assert_eq!(json, r#""int""#); + let json = serde_json::to_string(&ItemStatusType::Str).unwrap(); + assert_eq!(json, r#""str""#); + let json = serde_json::to_string(&ItemStatusType::None).unwrap(); + assert_eq!(json, r#""None""#); + } +} diff --git a/src/lib.rs b/src/lib.rs index ccd0972..0c87c1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,29 @@ -use std::time::Duration; - use project::Project; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio::time::Duration; pub mod archivist; pub mod item; pub mod project; pub mod task; +pub mod background; + +const TRACKER_NODES: [&str; 7] = [ + // "http://localhost:8080", // 测试环境 + "https://0.tracker.saveweb.org", + "https://1.tracker.saveweb.org", + "https://ipv4.1.tracker.saveweb.org", + "https://ipv6.1.tracker.saveweb.org", + // "https://2.tracker.saveweb.org", // 这台宕了 + "https://3.tracker.saveweb.org", + "https://ipv4.3.tracker.saveweb.org", + "https://ipv6.3.tracker.saveweb.org", +]; pub struct Tracker { - api_base: &'static str, - api_version: String, - // ping_client: reqwest::Client, // TODO + api_base: Arc>, + api_version: &'static str, project_id: String, http_client: reqwest::Client, client_version: String, @@ -18,42 +31,25 @@ pub struct Tracker { project: Option, } -const TRACKER_NODES: [&str; 9] = [ - "http://localhost:8080", // 测试环境 - "https://0.tracker.saveweb.org", - "https://1.tracker.saveweb.org", - "https://ipv4.1.tracker.saveweb.org", - "https://ipv6.1.tracker.saveweb.org", - "https://2.tracker.saveweb.org", // 这台宕了 - "https://3.tracker.saveweb.org", - "https://ipv4.3.tracker.saveweb.org", - "https://ipv6.3.tracker.saveweb.org", -]; - -pub fn get_tracker( - project_id: &str, - client_version: &str, - archivist: &str, -) -> Result> { - Ok(Tracker { - api_base: TRACKER_NODES[2], - api_version: "v1".into(), - // ping_client: reqwest::Client::builder() - // .timeout(Duration::from_secs(10)) - // .build()?, - project_id: project_id.to_string(), - http_client: reqwest::Client::builder() - .timeout(Duration::from_secs(120)) - .build()?, - client_version: client_version.to_string(), - archivist: archivist.to_string(), - project: None, - }) -} - impl Tracker { - fn start_select_tracker_background(&self) { - // todo + pub fn new( + project_id: String, + client_version: String, + archivist: String, + ) -> Result> { + let api_base = Arc::new(RwLock::new(TRACKER_NODES[1])); + Self::start_select_tracker_background(Arc::clone(&api_base)); + Ok(Tracker { + api_base, + api_version: "v1", + project_id: project_id, + http_client: reqwest::Client::builder() + .timeout(Duration::from_secs(60)) + .build()?, + client_version, + archivist, + project: None, + }) } } @@ -63,12 +59,11 @@ mod tests { #[tokio::test] async fn test_get_tracker() { - let mut tracker = get_tracker("test", "1.1", "neko").unwrap(); - // 但是不知道不加 tokio decorator 会不会有问题 + let mut tracker = Tracker::new("test".into(), "1.1".into(), "neko".into()).unwrap(); let project = tracker.get_project().await; println!("{:?}", project); let task = tracker.claim_task(true).await.unwrap(); - + println!("{:?}", task); let payload = r#"{"hhhh":123123, "f": 123.123}"#.to_string(); @@ -77,13 +72,4 @@ mod tests { .await; println!("{:?}", resp); } - // called `Result::unwrap()` on an `Err` value: Error("invalid type: integer `404`, expected struct Project", line: 1, column: 3) - // can you see terminal? - // yeap - // 我看看后端 -} // 是不是还少抄了什么 - // 写项目 调用的第一个应该是调哪个函数? - -// 就是先 get_tracker() 然后用 tracker 对象 .get_project() -// 意思是 async 了个寂寞? -// 问题不大, get_tracker 不需要 async +} diff --git a/src/project.rs b/src/project.rs index 466881c..a290d45 100644 --- a/src/project.rs +++ b/src/project.rs @@ -2,6 +2,7 @@ use serde::Deserialize; use crate::Tracker; +#[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct ProjectMeta { identifier: String, @@ -10,6 +11,7 @@ pub struct ProjectMeta { deadline: String, } +#[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct ProjectStatus { public: bool, @@ -18,10 +20,11 @@ pub struct ProjectStatus { #[derive(Debug, Deserialize)] pub struct ProjectClient { - version: String, - claim_task_delay: f64, // 用来做 QoS 的 + pub version: String, + pub claim_task_delay: f64, // 用来做 QoS 的 } +#[allow(dead_code)] #[derive(Debug, Deserialize)] pub struct ProjectMongodb { db_name: String, @@ -32,18 +35,20 @@ pub struct ProjectMongodb { #[derive(Debug, Deserialize)] pub struct Project { - meta: ProjectMeta, - status: ProjectStatus, - client: ProjectClient, - mongodb: ProjectMongodb, + pub meta: ProjectMeta, + pub status: ProjectStatus, + pub client: ProjectClient, + pub mongodb: ProjectMongodb, } impl Tracker { pub async fn fetch_project(&self) -> Result> { println!("fetch_project... {}", self.project_id); + let api_base = *self.api_base.read().await; + let url = format!( "{}/{}/project/{}", - self.api_base, self.api_version, self.project_id + api_base, self.api_version, self.project_id ); let res = self.http_client.post(&url).send().await?; // parse response as json diff --git a/src/task.rs b/src/task.rs index 3b2a0f0..686503f 100644 --- a/src/task.rs +++ b/src/task.rs @@ -2,7 +2,10 @@ use reqwest::Response; use serde::{Deserialize, Serialize}; use std::fmt::{self, Debug, Display}; -use crate::{item::{Item, ItemStatusType}, Tracker}; +use crate::{ + item::{Item, ItemStatusType}, + Tracker, +}; #[derive(Debug, Serialize, Deserialize)] pub enum Status { @@ -41,29 +44,34 @@ impl Display for Id { } } -// {"_id":"6663569c658e3647d062680b","archivist":"aaaa","claimed_at":"2024-07-08T18:54:17.463Z","id":23,"statu@OverflowCat ➜ /workspaces/stwp-rs (master) $ :argo test -- --nocapture +/// MongoDB ObjectId +type ObjectId = String; + #[derive(Debug, Serialize, Deserialize)] pub struct Task { - pub _id: String, - pub id: Id, // 也不行,我看看 + pub _id: ObjectId, + pub id: Id, pub status: Status, pub archivist: String, pub claimed_at: Option, pub updated_at: Option, } -// 要不写下测试? -// codespace 的 rust analyzer 好慢 impl Tracker { - pub async fn claim_task(&self, with_delay: bool) -> Option { + pub async fn claim_task(&mut self, with_delay: bool) -> Option { if with_delay { - // tokio::time::sleep(tokio::time::Duration::from_secs(t.project()) /* TODO */).await; + let project = self.get_project().await; + tokio::time::sleep(tokio::time::Duration::from_secs_f64( + project.client.claim_task_delay, + )) + .await; } - // resp, err := t.HTTP_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/claim_task", "", nil) + let api_base = *self.api_base.read().await; + let url = format!( "{}/{}/project/{}/{}/{}/claim_task", - self.api_base, self.api_version, self.project_id, self.client_version, self.archivist + api_base, self.api_version, self.project_id, self.client_version, self.archivist ); println!("{}", url); let resp = self.http_client.post(&url).send().await.unwrap(); @@ -75,10 +83,11 @@ impl Tracker { post_data.insert("status", to_status.to_string()); post_data.insert("task_id_type", task_id.to_string()); - // resp, err := t.HTTP_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/update_task/"+task_id, "application/x-www-form-urlencoded", strings.NewReader(postData.Encode())) + let api_base = *self.api_base.read().await; + let url = format!( "{}/{}/{}/{}/{}/update_task/{}", - self.api_base, + api_base, self.api_version, self.project_id, self.client_version, @@ -99,13 +108,12 @@ impl Tracker { if items.is_empty() { return "len(Items) == 0, nothing to insert".to_string(); } + + let api_base = *self.api_base.read().await; let url = format!( - // req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_many/" + fmt.Sprintf("%d", len(Items)) "{}/{}/project/{}/{}/{}/insert_many/{}", // TODO: 该找个 path builder 了? - // 今天先不管了 - - self.api_base, + api_base, self.api_version, self.project_id, self.client_version, @@ -130,10 +138,11 @@ impl Tracker { item_status: String, // TODO payload: String, ) -> String { - // req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_item/" + task.Id + let api_base = *self.api_base.read().await; + let url = format!( "{}/{}/project/{}/{}/{}/insert_item/{}", - self.api_base, + api_base, self.api_version, self.project_id, self.client_version, @@ -150,15 +159,6 @@ impl Tracker { // Payload string `json:"payload" binding:"required"` // } - // 感觉需要定义一个 ForPostItem(what?) 之类的东西…… - // 我后端没有从 json 类型来判断类型。 - - // 我后端写得烂,我的锅 - // 另外就是,我怕遇到 int64/float64+ 的 id,所以全部传 str,然后用 _type 来区分 - // 我看下 serde 文档 - - // client 需要 deserialize Item 吗?还是只发送不读取 - // 只发送ok // 也可以发 HTTP Form let item = Item { item_id: task.id.to_string(),