From 67c2c0dfbc7a91d9edddbd51551433ec97fb6acc Mon Sep 17 00:00:00 2001 From: yzqzss Date: Fri, 7 Jun 2024 17:34:17 +0800 Subject: [PATCH] first version --- src/saveweb_tracker/archivist.py | 17 + src/saveweb_tracker/task.go | 33 ++ src/saveweb_tracker/task.py | 40 +++ src/saveweb_tracker/tracker.go | 192 +++++++++++ src/saveweb_tracker/tracker.py | 421 +++++++++++++++++++++++ src/saveweb_tracker/tracker_project.go | 70 ++++ src/saveweb_tracker/tracker_task.go | 303 ++++++++++++++++ src/saveweb_tracker/tracker_task_test.go | 84 +++++ src/saveweb_tracker/tracker_test.go | 36 ++ tests/tracker_test.py | 18 + 10 files changed, 1214 insertions(+) create mode 100644 src/saveweb_tracker/archivist.py create mode 100644 src/saveweb_tracker/task.go create mode 100644 src/saveweb_tracker/task.py create mode 100644 src/saveweb_tracker/tracker.go create mode 100644 src/saveweb_tracker/tracker.py create mode 100644 src/saveweb_tracker/tracker_project.go create mode 100644 src/saveweb_tracker/tracker_task.go create mode 100644 src/saveweb_tracker/tracker_task_test.go create mode 100644 src/saveweb_tracker/tracker_test.go create mode 100644 tests/tracker_test.py diff --git a/src/saveweb_tracker/archivist.py b/src/saveweb_tracker/archivist.py new file mode 100644 index 0000000..8f55611 --- /dev/null +++ b/src/saveweb_tracker/archivist.py @@ -0,0 +1,17 @@ +import os + +def new_archivist(): + print("zh: 第一次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。(合法字符:字母、数字、-、_)") + print("en: First run, please input a string that can uniquely identify your node (for example: bob-gcloud-514). 
(Legal characters: letters, numbers, -, _)") + arch = input("ARCHIVIST: ") + with open("ARCHIVIST.conf", "w") as f: + f.write(arch) + return get_archivist() + +def get_archivist(): + if arch := os.getenv("ARCHIVIST", ""): + return arch + if os.path.exists("ARCHIVIST.conf"): + with open("ARCHIVIST.conf", "r") as f: + return f.read().splitlines()[0].strip() + return "" \ No newline at end of file diff --git a/src/saveweb_tracker/task.go b/src/saveweb_tracker/task.go new file mode 100644 index 0000000..3d097e8 --- /dev/null +++ b/src/saveweb_tracker/task.go @@ -0,0 +1,33 @@ +package savewebtracker + +// class Status: +// TODO = "TODO" +// PROCESSING = "PROCESSING" +// DONE = "DONE" +// EMPTY = "EMPTY" # 给 item 用的 + +// # TIMEOUT = "TIMEOUT" # 一直 PROCESSING 的任务,超时 +// FAIL = "FAIL" +// FEZZ = "FEZZ" +// """ 特殊: 任务冻结 """ + +const ( + StatusTODO = "TODO" + StatusPROCESSING = "PROCESSING" + StatusDONE = "DONE" + StatusEMPTY = "EMPTY" + + StatusFAIL = "FAIL" + StatusFEZZ = "FEZZ" +) + +type Status string + +var LIST_STATUS = []Status{ + StatusTODO, + StatusPROCESSING, + StatusDONE, + StatusEMPTY, + StatusFAIL, + StatusFEZZ, +} diff --git a/src/saveweb_tracker/task.py b/src/saveweb_tracker/task.py new file mode 100644 index 0000000..d776453 --- /dev/null +++ b/src/saveweb_tracker/task.py @@ -0,0 +1,40 @@ +from typing import Optional +from dataclasses import dataclass +from datetime import datetime + +class Status: + TODO = "TODO" + PROCESSING = "PROCESSING" + DONE = "DONE" + EMPTY = "EMPTY" # 给 item 用的 + + # TIMEOUT = "TIMEOUT" # 一直 PROCESSING 的任务,超时 + FAIL = "FAIL" + FEZZ = "FEZZ" + """ 特殊: 任务冻结 """ + +@dataclass +class Task: + _id: str + """ ObjectID """ + id: int + status: Status + archivist: str + + claimed_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + + def __post_init__(self): + assert self.status in Status.__dict__.values() + + def __repr__(self): + return f"Task({self.id}, status={self.status})" + + def __init__(self, _id, id, status, 
archivist, claimed_at, updated_at): + self._id = _id + self.id = id + self.status = status + self.archivist = archivist + self.claimed_at = claimed_at + self.updated_at = updated_at \ No newline at end of file diff --git a/src/saveweb_tracker/tracker.go b/src/saveweb_tracker/tracker.go new file mode 100644 index 0000000..5f8edf6 --- /dev/null +++ b/src/saveweb_tracker/tracker.go @@ -0,0 +1,192 @@ +package savewebtracker + +import ( + "compress/gzip" + "context" + "errors" + "fmt" + "math" + "net/http" + "regexp" + "sync" + "time" +) + +type ProjectMeta struct { + Identifier string `json:"identifier"` + Slug string `json:"slug"` // 一句话项目说明 + Icon string `json:"icon"` // 图标 URL + Deadline string `json:"deadline"` // 截止日期,没有格式要求 +} + +type ProjectStatus struct { + Public bool `json:"public"` // 是否公开。(不在 /projects 列表中列出,但 /project 可以请求) + Paused bool `json:"paused"` // 是否暂停。暂停后不再接受新的 claim_task 请求。但是仍可 update_task 和 insert_item +} + +type ProjectClient struct { + Version string `json:"version"` // 推荐 "大版本.小版本"。版本不对会拒绝各种请求。 + ClaimTaskDelay float32 `json:"claim_task_delay"` // claim_task 之后多久才能再次 claim_task +} + +type ProjectMongodb struct { + DbName string `json:"db_name"` + ItemCollection string `json:"item_collection"` + QueueCollection string `json:"queue_collection"` + CustomDocIDName string `json:"custom_doc_id_name"` // CustomIDName 目前只是为了兼容 lowapk_v2 存档项目而留的字段。它用的 feed_id 而不是 id + // 计划未来等转换好数据库 scheme 之后就弃用。 +} + +type Project struct { + Meta ProjectMeta `json:"meta"` + Status ProjectStatus `json:"status"` + Client ProjectClient `json:"client"` + Mongodb ProjectMongodb `json:"mongodb"` +} + +var ( + TRACKER_NODES = []string{ + "http://localhost:8080/", + "https://0.tracker.saveweb.org/", + "https://1.tracker.saveweb.org/", + "https://2.tracker.saveweb.org/", + "https://3.tracker.saveweb.org/", + } +) + +type Tracker struct { + API_BASE string + API_VERSION string + client_version string + project_id string + archivist string + http_client *http.Client + 
ping_client *http.Client + + _claim_wait_lock sync.Mutex + + __project *Project + __project_last_fetched time.Time // for cache __project + + __background_ping bool + __background_fetch_proj bool + + __gzPool sync.Pool +} + +func _is_safe_string(s string) bool { + match, _ := regexp.MatchString(`[^a-zA-Z0-9_\\-]`, s) + return !match +} + +func GetPing(ctx context.Context, pingClient *http.Client, node string) (*time.Duration, *http.Response, error) { + start := time.Now() + // resp, err := pingClient.Get(node + "ping") + req, err := http.NewRequestWithContext(ctx, "GET", node+"ping", nil) + if err != nil { + return nil, nil, err + } + resp, err := pingClient.Do(req) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + elapsed := time.Since(start) + if resp.StatusCode != 200 { + fmt.Println("status code not 200") + return nil, resp, errors.New("status code not 200") + } + return &elapsed, resp, nil +} + +func (t *Tracker) SelectBestTracker() *Tracker { + fmt.Println("[client->trackers] SelectBestTracker...") + type result struct { + node string + elapsed time.Duration + err error + } + + results := make(chan result, len(TRACKER_NODES)) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, node := range TRACKER_NODES { + go func(ctx context.Context, node string) { + elapsed, resp, err := GetPing(ctx, t.ping_client, node) + if err != nil { + if ctx.Err() == context.Canceled { + return + } + fmt.Println("ping error: ", err) + results <- result{node: node, elapsed: time.Duration(math.MaxInt64), err: err} + } else { + results <- result{node: node, elapsed: *elapsed, err: nil} + fmt.Println("node: ", node, "elapsed: ", elapsed, "protocol: ", resp.Proto) + } + }(ctx, node) + } + + best_node := "" + var best_elapsed time.Duration = time.Duration(math.MaxInt64) + + for range TRACKER_NODES { + res := <-results + if res.err == nil && res.elapsed < best_elapsed { + best_elapsed = res.elapsed + best_node = res.node + break // 
一个可用的节点就是最快的 + } + } + + cancel() + + if best_node == "" && t.API_BASE == "" { + panic("no tracker available") + } else if best_node != "" { + t.API_BASE = best_node + fmt.Println("best_node: ", best_node, "best_elapsed: ", best_elapsed) + fmt.Println("=====================================") + } + + return t +} + +func (t *Tracker) StartSelectTrackerBackground() *Tracker { + if t.__background_ping { + panic("another StartSelectTrackerBackground is running") + } + t.__background_ping = true + go func() { + for { + fmt.Println("[client->trackers] SelectBestTrackerBackground...") + t.SelectBestTracker() + time.Sleep(5 * time.Minute) + } + }() + return t +} +func GetTracker(project_id string, client_version string, archivist string) *Tracker { + t := &Tracker{ + API_VERSION: "v1", + ping_client: &http.Client{ + Timeout: 10 * time.Second, + }, + project_id: project_id, + http_client: &http.Client{ + Timeout: 120 * time.Second, + }, + client_version: client_version, + archivist: archivist, + __gzPool: sync.Pool{ + New: func() interface{} { + gz, err := gzip.NewWriterLevel(nil, gzip.BestCompression) + if err != nil { + panic(err) + } + return gz + }, + }, + } + return t +} diff --git a/src/saveweb_tracker/tracker.py b/src/saveweb_tracker/tracker.py new file mode 100644 index 0000000..f4b6041 --- /dev/null +++ b/src/saveweb_tracker/tracker.py @@ -0,0 +1,421 @@ +import asyncio +import copy +from dataclasses import dataclass +import gzip +import logging +import os +import re +import json +import time +from typing import Any, Dict, Optional, Union + +import httpx +from httpx._content import encode_urlencoded_data + +from .task import Task + +FORCE_TRACKER_NODE = os.getenv("FORCE_TRACKER_NODE") + + +logger = logging.getLogger(__name__) + +@dataclass +class ProjectMeta: + identifier: str + slug: str + icon: str + deadline: str +@dataclass +class ProjectStatus: + public: bool + paused: bool +@dataclass +class ProjectClient: + version: str + claim_task_delay: float + """ 客户端在 
claim_task 之后,多久之后才能再次 claim_task (软限制)""" +@dataclass +class ProjectMongodb: + db_name: str + item_collection: str + queue_collection: str + custom_doc_id_name: str + """ TODO: 未来弃用 """ +@dataclass +class Project: + meta: ProjectMeta + status: ProjectStatus + client: ProjectClient + mongodb: ProjectMongodb + + def __init__(self, meta: dict, status: dict, client: dict, mongodb: dict): + self.meta = ProjectMeta(**meta) + self.status = ProjectStatus(**status) + self.client = ProjectClient(**client) + self.mongodb = ProjectMongodb(**mongodb) + + +TRACKER_NODES = [ + "http://localhost:8080/", + "https://0.tracker.saveweb.org/", + "https://1.tracker.saveweb.org/", + "https://2.tracker.saveweb.org/", + "https://3.tracker.saveweb.org/", +] + +class ClientVersionOutdatedError(Exception): + pass + +class Tracker: + API_BASE: str = TRACKER_NODES[0] + API_VERSION = "v1" + client_version: str + project_id: str + """ [^a-zA-Z0-9_\\-] """ + archivist: str + """ [^a-zA-Z0-9_\\-] """ + sync_session: httpx.Client + session: httpx.AsyncClient + + _claim_wait_lock = asyncio.Lock() + + __project: Project|None = None + __project_last_fetched: float = 0 + + __last_task_claimed_at: float = 0 + + def _is_safe_string(self, s: str): + r = re.compile(r"[^a-zA-Z0-9_\\-]") + return not r.search(s) + + def __init__(self, project_id: str, client_version: str, archivist: str, session: httpx.AsyncClient|None = None): + """ + raise: ClientVersionOutdatedError + """ + assert self._is_safe_string(project_id) and self._is_safe_string(archivist), "[^a-zA-Z0-9_\\-]" + self.project_id = project_id + self.client_version = client_version + self.archivist = archivist + + sync_transport = httpx.HTTPTransport(retries=3, http2=True, http1=True) + transport = httpx.AsyncHTTPTransport(retries=3, http1=True, http2=True) + + self.sync_session = httpx.Client(http2=True, http1=True, timeout=120, transport=sync_transport) + self.session = httpx.AsyncClient(http2=True, http1=True, timeout=120, transport=transport) if 
session is None else session + + self.select_best_tracker() + + if self.project.client.version != self.client_version: + raise ClientVersionOutdatedError(f"client_version mismatch, please upgrade your client to {self.project.client.version}") + + print(f"[tracker] Hello, {self.archivist}!") + print(f"[tracker] Project: {self.project}") + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self.close() + await self.aclose() + + def close(self): + self.sync_session.close() + + async def aclose(self): + await self.session.aclose() + + @property + def project(self): + """ + return: Project. + return __project cache if it's not expired (60s) + """ + if self.__project is None or time.time() - self.__project_last_fetched > 60: + try: + self.__project = self.fetch_project(timeout=10) + except Exception as e: + print(f"[client->tracker] fetch_project failed. {e}") + if self.__project is None: + raise e + self.__project_last_fetched = time.time() + assert self.__project is not None + return self.__project + + def select_best_tracker(self): + """ + select best tracker node by latency + if FORCE_TRACKER_NODE is set, use it directly + """ + if FORCE_TRACKER_NODE: + self.API_BASE = FORCE_TRACKER_NODE + print(f"FORCE_TRACKER_NODE: {self.API_BASE}") + return + + result = [] # [(node, latency_score_float)] + print("[client->trackers] select_best_tracker...") + # for node in TRACKER_NODES: + # try: + # self.sync_session.get(node + 'ping', timeout=0.05) # DNS preload, dirty hack + # except Exception: + # pass + + for node in TRACKER_NODES: + start = time.time() + print(f'--- {node} ---') + try: + print("[client->tracker] \tping tracker...") + r = self.sync_session.get(node + 'ping', timeout=5) + r.raise_for_status() + ping_time = time.time() - start + print(f"[client<-tracker] \tping tracker OK. 
{ping_time:.2f}s ({r.http_version})") + + print("[client->tracker->MongoDB]\tping MongoDB...") + r = self.sync_session.get(node + 'ping_mongodb', timeout=8) + print(r.text) # {"NearestEalapsedTime":236,"PrimaryEalapsedTime":298,"message":"MongoDB is up"}. + r.raise_for_status() # 500 if MongoDB is down + PrimaryEalapsedTime: int = r.json()["PrimaryEalapsedTime"] + PrimaryEalapsedTimeSec = PrimaryEalapsedTime / 1000 + print(f"[client<-tracker<-MongoDB]\tping MongoDB OK. real latency: {ping_time+PrimaryEalapsedTimeSec:.2f}s") + result.append((node, (ping_time*3)+(PrimaryEalapsedTimeSec*0.3))) # DB latency is not very impactful + except Exception as e: + print(f"[client->tracker \tping failed. {e}") + result.append((node, float('inf'))) + result.sort(key=lambda x: x[1]) + self.API_BASE = result[0][0] + print("===============================") + print(f"tracker selected: {self.API_BASE}") + print("===============================") + + def fetch_project(self, timeout: int = 10): + """ + return: Project, deep copy + also refresh __project cache + """ + logger.info(f"[client->tracker] fetch_project... {self.project_id}") + r = self.sync_session.post(self.API_BASE + self.API_VERSION + '/project/' + self.project_id, timeout=timeout) + r.raise_for_status() + proj = Project(**r.json()) + self.__project = copy.deepcopy(proj) + self.__project_last_fetched = time.time() + logger.info(f"[client<-tracker] fetch_project. {self.project_id}") + return proj + + def get_projects(self): + logger.info(f"[client->tracker] get_projects.. {self.project_id}") + r = self.sync_session.post(self.API_BASE + self.API_VERSION + '/projects') + r.raise_for_status() + # print("[client<-tracker] get_projects. OK") + logger.info(f"[client<-tracker] get_projects. 
{self.project_id}") + return [Project(**p) for p in r.json()] + + + def _before_claim_task(self, with_delay: bool = True): + if with_delay: + if sleep_need := self.project.client.claim_task_delay - (time.time() - self.__last_task_claimed_at): + if sleep_need > 0: + # print(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}") + logger.info(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}") + time.sleep(sleep_need) + elif sleep_need < 0: + # print(f"[tracker] you are {sleep_need:.2f}s late, Qos: {self.project.client.claim_task_delay}") + logger.info(f"[tracker] you are {sleep_need:.2f}s late, Qos: {self.project.client.claim_task_delay}") + + async def _before_claim_task_async(self, with_delay: bool = True): + if with_delay: + if sleep_need := self.project.client.claim_task_delay - (time.time() - self.__last_task_claimed_at): + if sleep_need > 0: + # print(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}") + logger.info(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}") + await self._claim_wait_lock.acquire() + await asyncio.sleep(sleep_need) + self._claim_wait_lock.release() + elif sleep_need < 0: + logger.info(f"[tracker] you are {sleep_need:.2f}s late, Qos: {self.project.client.claim_task_delay}") + + def claim_task(self, with_delay: bool = True): + self._before_claim_task(with_delay) + + logger.info("[client->tracker] claim_task") + start = time.time() + r = self.sync_session.post(f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/claim_task') + return self._after_claim_task(r, start) + + async def claim_task_async(self, with_delay: bool = True): + """ + raise: ClientVersionOutdatedError + """ + await self._before_claim_task_async(with_delay) + + # print("[client->tracker] claim_task") + logger.info("[client->tracker] claim_task") + start = time.time() + r = 
await self.session.post(f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/claim_task') + return self._after_claim_task(r, start) + + def _after_claim_task(self, r: httpx.Response, start: float): + """ + raise: ClientVersionOutdatedError + """ + self.__last_task_claimed_at = time.time() + if r.status_code == 404: + # print(f'[client<-tracker] claim_task. (time cost: {time.time() - start:.2f}s):', r.text) + logger.info(f'[client<-tracker] claim_task. (time cost: {time.time() - start:.2f}s):', r.text) + return None # No tasks available + if r.status_code == 200: + r_json = r.json() + # print(f'[client<-tracker] claim_task. OK (time cost: {time.time() - start:.2f}s):', r_json) + logger.debug(f'[client<-tracker] claim_task. OK (time cost: {time.time() - start:.2f}s):', r_json) + task = Task(**r_json) + return task + if r.status_code == 400: + r_json = r.json() + if 'error' in r_json and r_json['error'] == 'Client version not supported': + raise ClientVersionOutdatedError(r.text) + raise Exception(r.status_code, r.text) + + def _before_update_task(self, task_id: Union[str, int], status: str): + assert isinstance(task_id, (str, int)) + assert isinstance(status, str) + # print(f"[client->tracker] update_task task({task_id}) to status({status})") + logger.info(f"[client->tracker] update_task task({task_id}) to status({status})") + + def update_task(self, task_id: Union[str, int], status: str): + """ + task_id: 必须明确传入的是 int 还是 str + status: 任务状态 + """ + self._before_update_task(task_id, status) + + r = self.sync_session.post( + f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/update_task/{task_id}', + data={ + 'status': status, + 'task_id_type': 'int' if isinstance(task_id, int) else 'str' + }) + return self._after_update_task(r, task_id) + + async def update_task_async(self, task_id: Union[str, int], status: str): + """ + task_id: 必须明确传入的是 int 还是 str + status: 任务状态 + + 
raise: ClientVersionOutdatedError + """ + self._before_update_task(task_id, status) + + r = await self.session.post( + f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/update_task/{task_id}', + data={ + 'status': status, + 'task_id_type': 'int' if isinstance(task_id, int) else 'str' + }) + return self._after_update_task(r, task_id) + + def _after_update_task(self, r: httpx.Response, task_id: Union[str, int])->Dict[str, Any]: + if r.status_code == 200: + r_json = r.json() + # print(f'[client<-tracker] update_task task({task_id}). OK:', r_json) + logger.info(f'[client<-tracker] update_task task({task_id}). OK:', r_json) + return r_json + if r.status_code == 400: + r_json = r.json() + if 'error' in r_json and r_json['error'] == 'Client version not supported': + raise ClientVersionOutdatedError(r.text) + raise Exception(r.text) + + def _before_insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]], payload: Optional[dict]): + # item_id_type + if isinstance(item_id, int): + item_id_type = 'int' + elif isinstance(item_id, str): + item_id_type = 'str' + else: + raise ValueError("item_id must be int or str") + + if item_status is None: + status_type = "None" + elif isinstance(item_status, int): + status_type = "int" + elif isinstance(item_status, str): + status_type = "str" + else: + raise ValueError("status must be int, str or None") + payload_json_str = json.dumps(payload, ensure_ascii=False, separators=(',', ':')) + # print(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}") + logger.info(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}") + + return item_id_type, status_type, payload_json_str + + def insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None): + """ + item_id: 必须明确传入的是 int 还是 str + item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 
之类的。就不用添加一堆错误响应到 payload 里了。 + payload: 可以传入任意的可转为 ext-json 的对象 (包括 None) + """ + item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload) + + r = self.sync_session.post( + f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', + data={ + 'item_id_type': item_id_type, + 'item_status': item_status, + 'item_status_type': status_type, + 'payload': payload_json_str + }) + return self._after_insert_item(r, item_id) + + async def insert_item_async(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None): + """ + item_id: 必须明确传入的是 int 还是 str + item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 之类的。就不用添加一堆错误响应到 payload 里了。 + payload: 可以传入任意的可转为 ext-json 的对象 (包括 None) + """ + item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload) + + data = { + 'item_id_type': item_id_type, + 'item_status': item_status, + 'item_status_type': status_type, + 'payload': payload_json_str + } + + + _, encoded_data_stream = encode_urlencoded_data(data) + encoded_data = b''.join(encoded_data_stream) + encoded_data_compressed = gzip.compress(encoded_data) + + rate = len(encoded_data_compressed) / len(encoded_data) + + + if rate < 0.95: + # low rate + gzip_headers = { "Content-Encoding": "gzip", "Content-Type": "application/x-www-form-urlencoded"} + logger.debug(f"compressed {len(encoded_data)} -> {len(encoded_data_compressed)}") + with open('debug.gzip', 'wb') as f: + f.write(encoded_data_compressed) + try: + r = await self.session.post( + f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', + content=encoded_data_compressed, headers=gzip_headers) + return self._after_insert_item(r, item_id) + except Exception as e: + print(f"insert fail, retry (without gzip compression)... 
{e}") + + # high rate or retry without compression + r = await self.session.post( + f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', + data=data) + return self._after_insert_item(r, item_id) + + raise RuntimeError("???? (^^)") + + def _after_insert_item(self, r: httpx.Response, item_id: Union[str, int])->Dict[str, Any]: + """ + return r_json if r.status_code == 200 else raise Exception(r.text) + """ + if r.status_code == 200: + r_json = r.json() + # print(f'[client<-tracker] insert_item item({item_id}). OK:', r_json) + logger.info(f'[client<-tracker] insert_item item({item_id}). OK:', r_json) + return r_json + raise Exception(r.text) diff --git a/src/saveweb_tracker/tracker_project.go b/src/saveweb_tracker/tracker_project.go new file mode 100644 index 0000000..b63323d --- /dev/null +++ b/src/saveweb_tracker/tracker_project.go @@ -0,0 +1,70 @@ +package savewebtracker + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "time" +) + +func (t *Tracker) Project() (proj Project) { + if time.Since(t.__project_last_fetched) < 3*time.Minute { + return *t.__project + } + + t.StartFetchProjectBackground() + for t.__project == nil { + fmt.Println("[client->tracker] waiting for project... ", t.project_id) + time.Sleep(5 * time.Second) + } + return t.Project() +} + +func (t *Tracker) StartFetchProjectBackground() { + if t.__background_fetch_proj { + return + } + t.__background_fetch_proj = true + go func() { + for { + go t.FetchProject(20 * time.Second) + time.Sleep(1 * time.Minute) + } + }() +} + +func (t *Tracker) FetchProject(timeout time.Duration) (proj *Project, err error) { + fmt.Println("[client->tracker] fetch_project... 
", t.project_id) + + ctx, cancel := context.WithTimeout(context.TODO(), timeout) + time.AfterFunc(timeout, func() { + cancel() + }) + + req, err := http.NewRequestWithContext(ctx, "POST", t.API_BASE+t.API_VERSION+"/project/"+t.project_id, nil) + if err != nil { + log.Print(err) + return nil, err + } + r, err := t.http_client.Do(req) + if err != nil { + log.Print(err) + return nil, err + } + defer r.Body.Close() + if r.StatusCode != 200 { + return nil, errors.New("status code not 200") + } + proj = &Project{} + err = json.NewDecoder(r.Body).Decode(proj) + if err != nil { + return nil, err + } + t.__project = proj + t.__project_last_fetched = time.Now() + fmt.Println("[client<-tracker] fetch_project. ", t.project_id) + return proj, nil +} diff --git a/src/saveweb_tracker/tracker_task.go b/src/saveweb_tracker/tracker_task.go new file mode 100644 index 0000000..89305d6 --- /dev/null +++ b/src/saveweb_tracker/tracker_task.go @@ -0,0 +1,303 @@ +package savewebtracker + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +type ObjectID string +type DatetimeUTC string + +func (s Status) Validate() bool { + for _, v := range LIST_STATUS { + if v == s { + return true + } + } + return false +} + +func (d DatetimeUTC) GetDatetime() time.Time { + // 2024-06-06T15:45:00.008Z + t, err := time.Parse(time.RFC3339, string(d)) + if err != nil { + panic(err) + } + return t +} + +type Task struct { + Obj_id ObjectID `json:"_id"` + Id string `json:"-"` + Id_raw json.RawMessage `json:"id"` + Id_type string `json:"-"` // str or int + Status Status `json:"status"` + Archivist string `json:"archivist"` // Optional + + Claimed_at DatetimeUTC `json:"claimed_at"` // Optional + Updated_at DatetimeUTC `json:"updated_at"` // Optional +} + +var ( + ErrorClientVersionOutdated = errors.New("client version outdated") + ENABLE_GZIP = true +) + +func (t *Tracker) ClaimTask(with_delay bool) *Task { + if with_delay { + 
t._claim_wait_lock.Lock() + time.Sleep(time.Duration(t.Project().Client.ClaimTaskDelay) * time.Second) + t._claim_wait_lock.Unlock() + } + resp, err := t.http_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/claim_task", "", nil) + if err != nil { + panic(err) + } + return _after_claim_task(resp) +} + +// 无任务返回 nil +func _after_claim_task(r *http.Response) *Task { + if r.StatusCode == 404 { + return nil + } + if r.StatusCode == 200 { + task := Task{} + err := json.NewDecoder(r.Body).Decode(&task) + if err != nil { + panic(err) + } + + var idInt int + var idString string + if err := json.Unmarshal(task.Id_raw, &idInt); err == nil { + idString = fmt.Sprintf("%d", idInt) + task.Id = idString + task.Id_type = "int" + } else if err := json.Unmarshal(task.Id_raw, &idString); err == nil { + task.Id = idString + task.Id_type = "str" + } else { + panic(err) + } + + return &task + } + + BodyBytes, _ := io.ReadAll(r.Body) + panic(string(BodyBytes)) +} + +func (t *Tracker) UpdateTask(task_id string, id_type string, to_status Status) string { + + postData := url.Values{} + postData.Set("status", string(to_status)) + postData.Set("task_id_type", id_type) + + if !to_status.Validate() { + fmt.Println("invalid status, to_status:", to_status) + panic("invalid status") + } + + resp, err := t.http_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/update_task/"+task_id, "application/x-www-form-urlencoded", strings.NewReader(postData.Encode())) + if err != nil { + panic(err) + } + return _after_update_task(resp) +} + +func _after_update_task(r *http.Response) string { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + text := string(bodyBytes) + + if r.StatusCode == 200 { + return text + } + if r.StatusCode == 400 { + panic(text) + } + + fmt.Println(r.StatusCode, r.Request.URL, text) + panic(text) +} + +// def _before_insert_item(self, item_id: Union[str, int], 
item_status: Optional[Union[str, int]], payload: Optional[dict]): +// # item_id_type +// if isinstance(item_id, int): +// item_id_type = 'int' +// elif isinstance(item_id, str): +// item_id_type = 'str' +// else: +// raise ValueError("item_id must be int or str") + +// if item_status is None: +// status_type = "None" +// elif isinstance(item_status, int): +// status_type = "int" +// elif isinstance(item_status, str): +// status_type = "str" +// else: +// raise ValueError("status must be int, str or None") +// payload_json_str = json.dumps(payload, ensure_ascii=False, separators=(',', ':')) +// # print(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}") +// logger.info(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}") + +// return item_id_type, status_type, payload_json_str + +func (t *Tracker) InsertItem(task Task, item_status string, status_type string, payload string) string { + if status_type != "int" && status_type != "str" && status_type != "None" { + panic("status must be int, str or None") + } + req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_item/" + task.Id + + data := url.Values{} + data.Set("item_id_type", task.Id_type) + data.Set("item_status", item_status) + data.Set("item_status_type", status_type) + data.Set("payload", payload) + encodedData := data.Encode() + + gzBuf := &bytes.Buffer{} + + gz := t.__gzPool.Get().(*gzip.Writer) + defer t.__gzPool.Put(gz) + defer gz.Reset(io.Discard) + defer gz.Close() + + gz.Reset(gzBuf) + if _, err := gz.Write([]byte(encodedData)); err != nil { + panic(err) + } + if err := gz.Flush(); err != nil { + panic(err) + } + gz.Close() + + len_encodedData := len([]byte(encodedData)) + // fmt.Printf("compressed %d -> %d \n", len_encodedData, gzBuf.Len()) + req := &http.Request{} + var err error + + if ENABLE_GZIP && float64(gzBuf.Len())/float64(len_encodedData) < 0.95 { // good 
compression rate + fmt.Println("using gzip...", gzBuf.Len()) + // req, err = http.NewRequest("POST", req_url, gzBuf) + req, err = http.NewRequest("POST", req_url, gzBuf) + if err != nil { + panic(err) + } + req.Header.Set("Content-Encoding", "gzip") + } else { + req, err = http.NewRequest("POST", req_url, strings.NewReader(encodedData)) + if err != nil { + panic(err) + } + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "*/*") + + fmt.Println(req.Header) + + resp, err := t.http_client.Do(req) + if err != nil { + panic(err) + } + return _after_insert_item(resp) +} + +func _after_insert_item(r *http.Response) string { + defer r.Body.Close() + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + text := string(bodyBytes) + + if r.StatusCode == 200 { + return text + } + + fmt.Println(r.StatusCode, r.Request.URL, text) + panic(text) +} + +// def insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None): +// """ +// item_id: 必须明确传入的是 int 还是 str +// item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 之类的。就不用添加一堆错误响应到 payload 里了。 +// payload: 可以传入任意的可转为 ext-json 的对象 (包括 None) +// """ +// item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload) + +// r = self.sync_session.post( +// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', +// data={ +// 'item_id_type': item_id_type, +// 'item_status': item_status, +// 'item_status_type': status_type, +// 'payload': payload_json_str +// }) +// return self._after_insert_item(r, item_id) + +// async def insert_item_async(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None): +// """ +// item_id: 必须明确传入的是 int 还是 str +// item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 之类的。就不用添加一堆错误响应到 payload 里了。 +// payload: 
可以传入任意的可转为 ext-json 的对象 (包括 None) +// """ +// item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload) + +// data = { +// 'item_id_type': item_id_type, +// 'item_status': item_status, +// 'item_status_type': status_type, +// 'payload': payload_json_str +// } + +// _, encoded_data_stream = encode_urlencoded_data(data) +// encoded_data = b''.join(encoded_data_stream) +// encoded_data_compressed = gzip.compress(encoded_data) + +// rate = len(encoded_data_compressed) / len(encoded_data) + +// if rate < 0.95: +// # low rate +// gzip_headers = { "Content-Encoding": "gzip", "Content-Type": "application/x-www-form-urlencoded"} +// logger.debug(f"compressed {len(encoded_data)} -> {len(encoded_data_compressed)}") +// try: +// r = await self.session.post( +// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', +// content=encoded_data_compressed, headers=gzip_headers) +// return self._after_insert_item(r, item_id) +// except Exception as e: +// print(f"insert fail, retry (without gzip compression)... {e}") + +// # high rate or retry without compression +// r = await self.session.post( +// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}', +// data=data) +// return self._after_insert_item(r, item_id) + +// raise RuntimeError("???? (^^)") + +// def _after_insert_item(self, r: httpx.Response, item_id: Union[str, int])->Dict[str, Any]: +// """ +// return r_json if r.status_code == 200 else raise Exception(r.text) +// """ +// if r.status_code == 200: +// r_json = r.json() +// # print(f'[client<-tracker] insert_item item({item_id}). OK:', r_json) +// logger.info(f'[client<-tracker] insert_item item({item_id}). 
OK:', r_json) +// return r_json +// raise Exception(r.text) diff --git a/src/saveweb_tracker/tracker_task_test.go b/src/saveweb_tracker/tracker_task_test.go new file mode 100644 index 0000000..6859ade --- /dev/null +++ b/src/saveweb_tracker/tracker_task_test.go @@ -0,0 +1,84 @@ +package savewebtracker + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +/* TestCliamTask claims one task from the tracker returned by SelectBestTracker, reports an error when no task comes back, and prints the task, its JSON encoding and a few of its fields. NOTE(review): the '\n' arguments are rune constants, so fmt.Println prints them as the number 10 instead of a line break; "\n" was probably intended. NOTE(review): "Cliam" looks like a typo for "Claim". */ +func TestCliamTask(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + task := tracker.ClaimTask(false) + if task == nil { + t.Error("task is nil") + } + fmt.Println(task) + json_str, _ := json.Marshal(&task) + fmt.Println(string(json_str)) + + if task != nil { + fmt.Println( + "Id_type", task.Id_type, '\n', + "Updated_at", task.Updated_at.GetDatetime(), '\n', + "Claimed_at", task.Claimed_at.GetDatetime(), '\n', + ) + } +} + +/* TestCliamAndUpdateTask claims a task, marks it StatusDONE through UpdateTask, and prints the task id and the text UpdateTask returned. */ +func TestCliamAndUpdateTask(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + task := tracker.ClaimTask(false) + if task == nil { + t.Error("task is nil") + return + } + text := tracker.UpdateTask(task.Id, task.Id_type, StatusDONE) + fmt.Println("task_id:", task.Id) + fmt.Println(text) +} + +/* TestBytes checks that a non-ASCII string is unchanged by a round trip through []byte. */ +func TestBytes(t *testing.T) { + b := "京哈随章长而窄喝着和" + assert.Equal(t, b, string([]byte(b))) +} + +/* TestClaimAndUpdateAndInsertTask claims a task, marks it StatusDONE, then uploads a nested sample payload for it with InsertItem using item status "null" and status type "None". */ +func TestClaimAndUpdateAndInsertTask(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + task := tracker.ClaimTask(false) + if task == nil { + t.Error("task is nil") + return + } + text := tracker.UpdateTask(task.Id, task.Id_type, StatusDONE) + fmt.Println("task_id:", task.Id, text) + + payload := map[string]interface{}{ + "kjasd": "111111111122222111111111122222111111111122222111111111122222", + "abasr": "111111111122222111111111122222111111111122222111111111122222", + "list": []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + "dict": map[string]interface{}{"a": 1, "b": 2, "c": 3}, + "nested": map[string]interface{}{ + "list": []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + "nested": map[string]interface{}{ + "list": []int{1, 2, 3, 4,
5, 6, 7, 8, 9}, + }, + }, + } + payload_json_str, _ := json.Marshal(payload) + text = tracker.InsertItem(*task, "null", "None", string(payload_json_str)) + fmt.Println("inserted task_id:", task.Id, text) +} + +/* TestClaimWithOldVersion sets client_version to "0.0" and expects ClaimTask to panic (presumably because the tracker rejects outdated clients; confirm against ClaimTask). */ +func TestClaimWithOldVersion(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + tracker.client_version = "0.0" + + assert.Panics(t, func() { + task := tracker.ClaimTask(false) + if task == nil { + t.Error("task is nil") + } + }) +} diff --git a/src/saveweb_tracker/tracker_test.go b/src/saveweb_tracker/tracker_test.go new file mode 100644 index 0000000..ab1b070 --- /dev/null +++ b/src/saveweb_tracker/tracker_test.go @@ -0,0 +1,36 @@ +package savewebtracker + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +/* GetDefaultTracker returns the Tracker built by GetTracker("test", "1.0", "archivist-test"); the tests in this package use it as their starting point. */ +func GetDefaultTracker() *Tracker { + return GetTracker("test", "1.0", "archivist-test") +} + +/* TestIsSafeString checks that _is_safe_string accepts strings made of letters, digits, '-' and '_' as well as the empty string, and rejects the samples containing '!' or "/..". */ +func TestIsSafeString(t *testing.T) { + assert.True(t, _is_safe_string("abc")) + assert.True(t, _is_safe_string("abc-gg36c3rc-5vw_")) + assert.True(t, _is_safe_string("")) + assert.False(t, _is_safe_string("abc-5vw_!")) + assert.False(t, _is_safe_string("asdasd/..")) +} + +/* TestGetTracker checks that the tracker returned by SelectBestTracker has a non-empty API_BASE. */ +func TestGetTracker(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + assert.NotEmpty(t, tracker.API_BASE) +} + +/* TestGetProject checks that __background_fetch_proj is false before and true after the first Project call, that Project returns a non-nil value, and that a second call returns an equal project. */ +func TestGetProject(t *testing.T) { + tracker := GetDefaultTracker().SelectBestTracker() + assert.False(t, tracker.__background_fetch_proj) + proj := tracker.Project() + assert.True(t, tracker.__background_fetch_proj) + assert.NotNil(t, proj) + fmt.Println(proj) + same_proj := tracker.Project() + assert.Equal(t, proj, same_proj) +} diff --git a/tests/tracker_test.py b/tests/tracker_test.py new file mode 100644 index 0000000..9904bf9 --- /dev/null +++ b/tests/tracker_test.py @@ -0,0 +1,18 @@ +import pytest +import saveweb_tracker +import saveweb_tracker.tracker + +@pytest.mark.asyncio +async def test_tracker(): + async with saveweb_tracker.tracker.Tracker("test", "1.0", "test-python") as tracker: + task = await
tracker.claim_task_async(with_delay=False) + print(task) + assert task is not None + text = await tracker.update_task_async(task_id=task.id, status="DONE") + print(text) + text = await tracker.insert_item_async( + item_id=task.id, item_status=None, payload={ + "kjasd": "111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222", + } + ) + print(text) \ No newline at end of file