This commit is contained in:
OverflowCat 2024-07-09 00:48:13 +08:00 committed by ᡥᠠᡳᡤᡳᠶᠠ ᡥᠠᠯᠠ·ᠨᡝᡴᠣ 猫
commit 6e59348f5b
8 changed files with 2084 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/target

1447
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

16
Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
[package]
name = "stwp"
description = "Tracker for Save the Web Project"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
# HTTP client; "json" enables Response::json() used for API payloads.
reqwest = { version = "0.12", features = ["json"] }
# Async runtime ("full" pulls in time, macros, etc.).
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = [
"derive",
] }
serde_json = "1.0"
# Date/time handling (timestamp parsing is still TODO in src/task.rs).
chrono = "0.4.38"

56
src/archivist.rs Normal file
View File

@@ -0,0 +1,56 @@
use std::fs::File;
use std::io::{self, BufRead, Write};
use std::path::Path;
use std::env;
const CONFIG_FILE: &str = "ARCHIVIST.conf";
pub fn get_archivist() -> String {
if let Some(arch) = read_archivist() {
if !is_safe_string(&arch) {
panic!("ARCHIVIST contains illegal characters");
}
arch
} else {
let arch = new_archivist();
if !is_safe_string(&arch) {
panic!("ARCHIVIST contains illegal characters");
}
arch
}
}
fn new_archivist() -> String {
println!("zh: 初次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。合法字符字母、数字、-、_");
println!("en: This is your first time running this program. Please enter a string that uniquely identifies your node, e.g. alice-aws-114. (Legal characters: letters, numbers, -, _)");
print!("ARCHIVIST: ");
io::stdout().flush().unwrap();
let mut arch = String::new();
io::stdin().read_line(&mut arch).expect("Failed to read input");
let arch = arch.trim().to_string();
let mut file = File::create(CONFIG_FILE).expect("Failed to create file");
writeln!(file, "{}", arch).expect("Failed to write to file");
read_archivist().expect("Failed to get archivist")
}
/// Resolves the archivist identifier, if one is configured.
///
/// Resolution order: `ARCHIVIST` env var first, then the first line of
/// `CONFIG_FILE`. Returns `None` when neither source yields a value.
fn read_archivist() -> Option<String> {
    if let Ok(arch) = env::var("ARCHIVIST") {
        return Some(arch);
    }
    if Path::new(CONFIG_FILE).exists() {
        // An unreadable or empty file now yields None (triggering fresh
        // setup) instead of panicking; the original also used a `for` loop
        // that returned on its first iteration (clippy: never_loop).
        let file = File::open(CONFIG_FILE).ok()?;
        return io::BufReader::new(file).lines().next()?.ok();
    }
    None
}
/// Returns true when `s` contains only ASCII letters, ASCII digits, `-` or
/// `_` — the charset promised by the first-run prompt. The original used
/// `char::is_alphanumeric`, which also accepts every Unicode letter/digit
/// (e.g. CJK characters), contradicting the documented rule.
fn is_safe_string(s: &str) -> bool {
    s.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
}

18
src/item.rs Normal file
View File

@@ -0,0 +1,18 @@
/// How an item's upstream id is typed: a string or an unsigned integer.
pub enum ItemIdType {
Str(String),
Int(u64),
}
/// An item's status value; `None` marks "no status".
/// NOTE(review): presumably mirrors the Go client's status_type values
/// "int" / "str" / "None" (see the reference code in src/task.rs) — confirm.
pub enum ItemStatus {
None,
Str(String),
Int(u64),
}
/// One archived record to report back to the tracker.
/// NOTE(review): fields are private and not yet read or written by any
/// code visible here; they appear to mirror the Go client's `Item` —
/// confirm the intended (de)serialization before relying on them.
pub struct Item {
item_id: String, // id rendered as a string
item_id_type: ItemIdType, // whether the upstream id was numeric or a string
item_status: String,
item_status_type: ItemStatus,
payload: String, // project-defined payload
}

79
src/lib.rs Normal file
View File

@@ -0,0 +1,79 @@
use std::time::Duration;
pub mod item;
pub mod task;
pub mod archivist;
pub mod project;
/// Handle for talking to a Save the Web tracker node.
pub struct Tracker {
api_base: &'static str, // one of TRACKER_NODES; ends with '/'
api_version: String, // e.g. "v1" (set in get_tracker)
ping_client: reqwest::Client, // short-timeout client (10 s in get_tracker), for liveness pings
project_id: String,
http_client: reqwest::Client, // long-timeout client (120 s in get_tracker), for API calls
client_version: String,
archivist: String, // node operator identifier (see src/archivist.rs)
}
/// Known tracker endpoints. `get_tracker` currently hard-codes index 2;
/// dynamic node selection is still a stub (start_select_tracker_background).
const TRACKER_NODES: [&str; 9] = [
"http://localhost:8080/", // test environment
"https://0.tracker.saveweb.org/",
"https://1.tracker.saveweb.org/",
"https://ipv4.1.tracker.saveweb.org/",
"https://ipv6.1.tracker.saveweb.org/",
"https://2.tracker.saveweb.org/", // this node is down
"https://3.tracker.saveweb.org/",
"https://ipv4.3.tracker.saveweb.org/",
"https://ipv6.3.tracker.saveweb.org/",
];
/*
func GetTracker(project_id string, client_version string, archivist string) *Tracker {
t := &Tracker{
API_VERSION: "v1",
PING_client: &http.Client{
Timeout: 10 * time.Second,
},
project_id: project_id,
HTTP_client: &http.Client{
Timeout: 120 * time.Second,
},
client_version: client_version,
archivist: archivist,
__gzPool: sync.Pool{
New: func() interface{} {
gz, err := gzip.NewWriterLevel(nil, gzip.BestCompression)
if err != nil {
panic(err)
}
return gz
},
},
}
return t
}
*/
#[tokio::main]
pub async fn get_tracker(project_id: &str, client_version: &str, archivist: &str) -> Result<Tracker, Box<dyn std::error::Error>> {
Ok(
Tracker {
api_base: TRACKER_NODES[2],
api_version: "v1".into(),
ping_client: reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()?,
project_id: project_id.to_string(),
http_client: reqwest::Client::builder()
.timeout(Duration::from_secs(120))
.build()?,
client_version: client_version.to_string(),
archivist: archivist.to_string(),
})
}
impl Tracker {
/// Stub: intended to probe TRACKER_NODES in the background and switch
/// `api_base` to a healthy node. Not implemented yet.
fn start_select_tracker_background(&self) {
// todo
}
}

134
src/project.rs Normal file
View File

@@ -0,0 +1,134 @@
use serde::Deserialize;
use crate::Tracker;
/// Project descriptor as reported by the tracker.
#[derive(Debug, Deserialize)]
pub struct ProjectMeta {
identifier: String,
slug: String,
icon: String,
deadline: String, // NOTE(review): format not specified here — confirm against the server
}
/// Project-level flags.
#[derive(Debug, Deserialize)]
pub struct ProjectStatus {
public: bool,
paused: bool,
}
/// Client-facing tuning parameters.
#[derive(Debug, Deserialize)]
pub struct ProjectClient {
version: String,
claim_task_delay: f64, // used for QoS: per-claim delay in seconds (the Go client multiplies it by time.Second)
}
/// Server-side MongoDB layout for this project.
#[derive(Debug, Deserialize)]
pub struct ProjectMongodb {
db_name: String,
item_collection: String,
queue_collection: String,
custom_doc_id_name: String,
}
/// Everything the tracker reports about a project.
#[derive(Debug, Deserialize)]
pub struct Project {
meta: ProjectMeta,
status: ProjectStatus,
client: ProjectClient,
mongodb: ProjectMongodb,
}
impl Tracker {
// 我写的是先同步获取一次 project ,然后后台每一分钟获取一次,然后超过几分钟没有正常拿到 project就 panic
// 我先不管后台的,跑起来再说
// 草
// 中肯的
// pub async fn project() TODO
pub async fn fetch_project(&self) -> Result<Project, Box<dyn std::error::Error>> {
println!("fetch_project... {}", self.project_id);
// curl -X POST https://0.tracker.saveweb.org/v1/project/test
let url = format!("{}/project/{}", self.api_base, self.project_id);
let res = self.http_client.post(&url).send().await?;
// parse response as json
let project: Project = serde_json::from_str(&res.text().await?)?;
Ok(project)
}
}
/*
package savewebtracker
func (t *Tracker) Project() (proj Project) {
if time.Since(t.__project_last_fetched) <= 3*time.Minute {
return *t.__project
}
t.StartFetchProjectBackground()
for t.__project == nil { // initial fetch
time.Sleep(1 * time.Second)
if t.__project != nil { // fetch success
return t.Project()
}
}
for { // not nil, but outdated
if time.Since(t.__project_last_fetched) > 5*time.Minute { // over 5 minutes, abort
panic("all fetch failed for 5 minutes")
}
if time.Since(t.__project_last_fetched) <= 3*time.Minute { // not outdated anymore
return *t.__project
}
go t.FetchProject(5 * time.Second) // short timeout
time.Sleep(8 * time.Second)
}
}
func (t *Tracker) StartFetchProjectBackground() *Tracker {
if t.__background_fetch_proj {
return t
}
t.__background_fetch_proj = true
go func() {
for {
go t.FetchProject(20 * time.Second)
time.Sleep(1 * time.Minute)
}
}()
return t
}
func (t *Tracker) FetchProject(timeout time.Duration) (proj *Project, err error) {
fmt.Println("[client->tracker] fetch_project... ", t.project_id)
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
time.AfterFunc(timeout, func() {
cancel()
})
req, err := http.NewRequestWithContext(ctx, "POST", t.API_BASE+t.API_VERSION+"/project/"+t.project_id, nil)
if err != nil {
log.Print(err)
return nil, err
}
r, err := t.HTTP_client.Do(req)
if err != nil {
log.Print(err)
return nil, err
}
defer r.Body.Close()
if r.StatusCode != 200 {
return nil, errors.New("status code not 200")
}
proj = &Project{}
err = json.NewDecoder(r.Body).Decode(proj)
if err != nil {
return nil, err
}
t.__project = proj
t.__project_last_fetched = time.Now()
fmt.Println("[client<-tracker] fetch_project. ", t.project_id)
return proj, nil
}
*/

333
src/task.rs Normal file
View File

@@ -0,0 +1,333 @@
use std::fmt::{self, Debug, Display};

use chrono::TimeZone;
use serde::{Deserialize, Serialize};

use crate::Tracker;
/// Task lifecycle states. The `serde(rename = ...)` values are the
/// on-the-wire names the tracker exchanges.
#[derive(Debug, Serialize, Deserialize)]
pub enum Status {
#[serde(rename = "TODO")]
Todo,
#[serde(rename = "PROCESSING")]
Processing,
#[serde(rename = "DONE")]
Done,
#[serde(rename = "FAIL")]
Fail,
/// Special: task frozen (set some tasks to FEZZ to prevent repeated re-queueing)
#[serde(rename = "FEZZ")]
Fezz,
} // Each project may define its own statuses;
// only TODO and PROCESSING are mandatory.
impl Display for Status {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Debug::fmt(self, f)
}
}
// MongoDB document id rendered as a string (cf. Go `type ObjectID string`).
pub type ObjectID = String;
/// Project-level task id: integer or string depending on the project.
#[derive(Debug, Serialize, Deserialize)]
pub enum Id {
Int(i64),
Str(String),
}
impl Display for Id {
    /// Writes the underlying integer or string id verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Id::Int(value) => Display::fmt(value, f),
            Id::Str(text) => f.write_str(text),
        }
    }
}
/// A unit of work claimed from the tracker.
/// NOTE(review): the Go client carries a raw id plus an id_type field and
/// decodes it manually; this struct relies on serde to pick the right `Id`
/// variant — confirm it matches the wire format.
#[derive(Debug, Serialize, Deserialize)]
pub struct Task {
obj_id: ObjectID, // MongoDB document id
id: Id, // project-level task id
status: Status,
archivist: String, // node that claimed the task
claimed_at: Option<String>, // timestamp string, e.g. "2024-06-06T15:45:00.008Z" (parsing is TODO)
updated_at: Option<String>,
}
// pub fn get_datetime(d: String) -> chrono::DateTime<chrono::Utc> {
// 2024-06-06T15:45:00.008Z
// let t = TimeZone::from
// return t;
// }
impl Tracker {
pub async fn claim_task(with_delay: bool) -> Option<Task> {
if with_delay {
tokio::time::sleep(tokio::time::Duration::from_secs(1) /* TODO */).await;
}
let resp = reqwest::get("https://www.rust-lang.org").await?;
return Task::after_claim_task(resp).await;
}
async fn after_claim_task(r: reqwest::Response) -> Option<Task> {
if r.status() == 404 {
return None;
}
if r.status() == 200 {
let task: Option<Task> = r.json().await.ok();
return task;
}
let body = r.text().await.unwrap();
panic!("{}", body);
}
pub async fn update_task(&self, task_id: Id, to_status: Status) -> String {
let mut post_data = std::collections::HashMap::new();
post_data.insert("status", to_status.to_string());
post_data.insert("task_id_type", task_id.to_string());
let url = format!("{}}/{}/{}", self.obj_id, self.archivist);
// let resp
// let resp = reqwest::post().form(&post_data).send().await?;
return after_update_task(resp).await.unwrap();
}
}
/// Returns the response body on HTTP 200; panics with the body text on any
/// other status (a 400 is treated the same as other failures for now).
async fn after_update_task(r: reqwest::Response) -> Option<String> {
    let status = r.status();
    let body = r.text().await.ok()?;
    match status.as_u16() {
        200 => Some(body),
        _ => panic!("{}", body),
    }
}
/*
type ObjectID string
type DatetimeUTC string
func (d DatetimeUTC) GetDatetime() time.Time {
// 2024-06-06T15:45:00.008Z
t, err := time.Parse(time.RFC3339, string(d))
if err != nil {
panic(err)
}
return t
}
var (
ErrorClientVersionOutdated = errors.New("client version outdated")
ENABLE_GZIP = true
)
func (t *Tracker) ClaimTask(with_delay bool) *Task {
if with_delay {
t._claim_wait_lock.Lock()
time.Sleep(time.Duration(t.Project().Client.ClaimTaskDelay * float64(time.Second)))
t._claim_wait_lock.Unlock()
}
resp, err := t.HTTP_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/claim_task", "", nil)
if err != nil {
panic(err)
}
return _after_claim_task(resp)
}
// 无任务返回 nil
func _after_claim_task(r *http.Response) *Task {
if r.StatusCode == 404 {
return nil // 无任务
}
if r.StatusCode == 200 {
task := Task{}
err := json.NewDecoder(r.Body).Decode(&task)
if err != nil {
panic(err)
}
var idInt int
var idString string
if err := json.Unmarshal(task.Id_raw, &idInt); err == nil {
idString = fmt.Sprintf("%d", idInt)
task.Id = idString
task.Id_type = "int"
} else if err := json.Unmarshal(task.Id_raw, &idString); err == nil {
task.Id = idString
task.Id_type = "str"
} else {
panic(err)
}
return &task
}
BodyBytes, _ := io.ReadAll(r.Body)
panic(string(BodyBytes))
}
func (t *Tracker) UpdateTask(task_id string, id_type string, to_status Status) string {
postData := url.Values{}
postData.Set("status", string(to_status))
postData.Set("task_id_type", id_type)
if !to_status.Validate() {
fmt.Println("invalid status, to_status:", to_status)
panic("invalid status")
}
resp, err := t.HTTP_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/update_task/"+task_id, "application/x-www-form-urlencoded", strings.NewReader(postData.Encode()))
if err != nil {
panic(err)
}
return _after_update_task(resp)
}
func _after_update_task(r *http.Response) string {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
text := string(bodyBytes)
if r.StatusCode == 200 {
return text
}
if r.StatusCode == 400 {
panic(text)
}
fmt.Println(r.StatusCode, r.Request.URL, text)
panic(text)
}
func (t *Tracker) InsertMany(Items []Item) string {
if len(Items) == 0 {
return "len(Items) == 0, nothing to insert"
}
req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_many/" + fmt.Sprintf("%d", len(Items))
items_json_str, err := json.Marshal(Items)
if err != nil {
panic(err)
}
len_encodedData := len(items_json_str)
gzBuf, err := t.GzCompress(items_json_str)
if err != nil {
panic(err)
}
req := &http.Request{}
if ENABLE_GZIP && float64(gzBuf.Len())/float64(len_encodedData) < 0.95 { // good compression rate
req, err = http.NewRequest("POST", req_url, gzBuf)
if err != nil {
panic(err)
}
req.Header.Set("Content-Encoding", "gzip")
} else {
req, err = http.NewRequest("POST", req_url, bytes.NewReader(items_json_str))
if err != nil {
panic(err)
}
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Accept", "* / *")
resp, err := t.HTTP_client.Do(req)
if err != nil {
panic(err)
}
return _after_insert_item(resp)
}
func (t *Tracker) GzCompress(data []byte) (*bytes.Buffer, error) {
gzBuf := &bytes.Buffer{}
gz := t.__gzPool.Get().(*gzip.Writer)
defer t.__gzPool.Put(gz)
defer gz.Reset(io.Discard)
defer gz.Close()
gz.Reset(gzBuf)
if _, err := gz.Write(data); err != nil {
return nil, err
}
if err := gz.Flush(); err != nil {
return nil, err
}
gz.Close()
return gzBuf, nil
}
func (t *Tracker) InsertItem(task Task, item_status string, status_type string, payload string) string {
if status_type != "int" && status_type != "str" && status_type != "None" {
panic("status must be int, str or None")
}
req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_item/" + task.Id
var err error
item := Item{
Item_id: task.Id,
Item_id_type: task.Id_type,
Item_status: item_status,
Item_status_type: status_type,
Payload: payload,
}
data, err := json.Marshal(item)
if err != nil {
panic(err)
}
len_data := len(data)
gzBuf, err := t.GzCompress(data)
if err != nil {
panic(err)
}
// fmt.Printf("compressed %d -> %d \n", len_encodedData, gzBuf.Len())
req := &http.Request{}
if ENABLE_GZIP && float64(gzBuf.Len())/float64(len_data) < 0.95 { // good compression rate
req, err = http.NewRequest("POST", req_url, gzBuf)
if err != nil {
panic(err)
}
req.Header.Set("Content-Encoding", "gzip")
} else {
req, err = http.NewRequest("POST", req_url, bytes.NewReader(data))
if err != nil {
panic(err)
}
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Accept", "* / *")
resp, err := t.HTTP_client.Do(req)
if err != nil {
panic(err)
}
return _after_insert_item(resp)
}
func _after_insert_item(r *http.Response) string {
defer r.Body.Close()
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}
text := string(bodyBytes)
if r.StatusCode == 200 {
return text
}
fmt.Println(r.StatusCode, r.Request.URL, text)
panic(text)
}
*/