first version

This commit is contained in:
yzqzss 2024-06-07 17:34:17 +08:00
parent cc96ca8f6e
commit 67c2c0dfbc
10 changed files with 1214 additions and 0 deletions

View File

@ -0,0 +1,17 @@
import os
def new_archivist():
    """Interactively ask for a node identifier, persist it to ARCHIVIST.conf,
    and return the value as resolved by get_archivist().

    Keeps prompting until the input matches the charset advertised in the
    prompt (letters, digits, '-', '_'); the original accepted anything.
    """
    import re

    print("zh: 第一次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。合法字符字母、数字、-、_")
    print("en: First run, please input a string that can uniquely identify your node (for example: bob-gcloud-514). (Legal characters: letters, numbers, -, _)")
    while True:
        arch = input("ARCHIVIST: ").strip()
        # Enforce the promised charset and reject empty input.
        if re.fullmatch(r"[A-Za-z0-9_-]+", arch):
            break
        print("Invalid ARCHIVIST, please try again. (Legal characters: letters, numbers, -, _)")
    with open("ARCHIVIST.conf", "w") as f:
        f.write(arch)
    return get_archivist()
def get_archivist():
    """Resolve the node identifier.

    Precedence: the ARCHIVIST environment variable, then the first non-empty
    line of ARCHIVIST.conf, else "".

    The original indexed ``splitlines()[0]``, which raised IndexError when the
    config file existed but was empty; iterating fixes that edge case.
    """
    if arch := os.getenv("ARCHIVIST", ""):
        return arch
    if os.path.exists("ARCHIVIST.conf"):
        with open("ARCHIVIST.conf", "r") as f:
            for line in f.read().splitlines():
                if stripped := line.strip():
                    return stripped
    return ""

View File

@ -0,0 +1,33 @@
package savewebtracker
// Ported from the original Python implementation:
//
//	class Status:
//	    TODO = "TODO"
//	    PROCESSING = "PROCESSING"
//	    DONE = "DONE"
//	    EMPTY = "EMPTY"        # used by items only
//	    # TIMEOUT = "TIMEOUT"  # reserved: tasks stuck in PROCESSING (timed out)
//	    FAIL = "FAIL"
//	    FEZZ = "FEZZ"          # special: the task is frozen
const (
	StatusTODO       = "TODO"
	StatusPROCESSING = "PROCESSING"
	StatusDONE       = "DONE"
	StatusEMPTY      = "EMPTY"
	StatusFAIL       = "FAIL"
	StatusFEZZ       = "FEZZ"
)

// Status is the lifecycle state of a task (or item) on the tracker.
type Status string

// LIST_STATUS enumerates every valid Status value; used by Status.Validate.
var LIST_STATUS = []Status{
	StatusTODO,
	StatusPROCESSING,
	StatusDONE,
	StatusEMPTY,
	StatusFAIL,
	StatusFEZZ,
}

View File

@ -0,0 +1,40 @@
from typing import Optional
from dataclasses import dataclass
from datetime import datetime
class Status:
    # Lifecycle states shared by tasks and items.
    TODO = "TODO"
    PROCESSING = "PROCESSING"
    DONE = "DONE"
    # EMPTY is only meaningful for items.
    EMPTY = "EMPTY"
    # TIMEOUT = "TIMEOUT"  # reserved: tasks stuck in PROCESSING (timed out)
    FAIL = "FAIL"
    # FEZZ is special: the task is frozen.
    FEZZ = "FEZZ"
@dataclass
class Task:
    """One unit of work handed out by the tracker.

    The original defined a manual ``__init__`` inside a ``@dataclass`` body;
    dataclass does not overwrite an existing ``__init__``, so the generated
    init (with its ``claimed_at``/``updated_at`` defaults) and the
    ``__post_init__`` validation never ran, and the defaults were effectively
    required arguments. Dropping the manual ``__init__`` restores both.
    """

    _id: str  # MongoDB ObjectID, as a string
    id: int
    status: Status
    archivist: str  # node identifier of whoever claimed the task
    claimed_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    def __post_init__(self):
        # Validate against the declared Status constants only, not every entry
        # of Status.__dict__ (which also contains dunder values).
        valid = {v for k, v in vars(Status).items() if not k.startswith("_")}
        assert self.status in valid, f"invalid status: {self.status!r}"

    def __repr__(self):
        return f"Task({self.id}, status={self.status})"

View File

@ -0,0 +1,192 @@
package savewebtracker
import (
	"compress/gzip"
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"regexp"
	"sync"
	"time"
)
// ProjectMeta is the project's display metadata.
type ProjectMeta struct {
	Identifier string `json:"identifier"`
	Slug       string `json:"slug"`     // one-line project description
	Icon       string `json:"icon"`     // icon URL
	Deadline   string `json:"deadline"` // deadline, free-form text (no fixed format)
}

// ProjectStatus carries the project's visibility/availability flags.
type ProjectStatus struct {
	Public bool `json:"public"` // when false the project is omitted from /projects, but /project still answers
	Paused bool `json:"paused"` // paused projects reject new claim_task calls; update_task and insert_item still work
}

// ProjectClient describes what the tracker expects from clients.
type ProjectClient struct {
	Version        string  `json:"version"`          // recommended form "major.minor"; a mismatch gets requests rejected
	ClaimTaskDelay float32 `json:"claim_task_delay"` // how long to wait after claim_task before the next claim_task
}

// ProjectMongodb names the MongoDB database/collections backing the project.
type ProjectMongodb struct {
	DbName          string `json:"db_name"`
	ItemCollection  string `json:"item_collection"`
	QueueCollection string `json:"queue_collection"`
	// CustomDocIDName exists only for compatibility with the lowapk_v2
	// archive project, which uses feed_id instead of id. Planned to be
	// dropped once the database schema has been migrated.
	CustomDocIDName string `json:"custom_doc_id_name"`
}

// Project aggregates everything the tracker reports about one project.
type Project struct {
	Meta    ProjectMeta    `json:"meta"`
	Status  ProjectStatus  `json:"status"`
	Client  ProjectClient  `json:"client"`
	Mongodb ProjectMongodb `json:"mongodb"`
}
var (
	// TRACKER_NODES lists the candidate tracker endpoints probed by
	// SelectBestTracker; the localhost entry serves local development.
	TRACKER_NODES = []string{
		"http://localhost:8080/",
		"https://0.tracker.saveweb.org/",
		"https://1.tracker.saveweb.org/",
		"https://2.tracker.saveweb.org/",
		"https://3.tracker.saveweb.org/",
	}
)
// Tracker is a client for one saveweb tracker project. Fields prefixed with
// "__" are internal caches/flags (names kept from the Python original).
type Tracker struct {
	API_BASE         string // selected tracker endpoint, e.g. "https://0.tracker.saveweb.org/"
	API_VERSION      string // currently "v1"
	client_version   string
	project_id       string
	archivist        string
	http_client      *http.Client // API requests (long timeout)
	ping_client      *http.Client // latency probes (short timeout)
	_claim_wait_lock sync.Mutex   // serializes the claim_task delay across goroutines

	__project              *Project  // cached project metadata
	__project_last_fetched time.Time // for cache __project
	__background_ping      bool      // set once StartSelectTrackerBackground is running
	__background_fetch_proj bool     // set once StartFetchProjectBackground is running
	__gzPool               sync.Pool // pooled *gzip.Writer for InsertItem bodies
}
// safeStringRe matches any character outside the allowed identifier set.
// Compiled once at package level instead of on every call. Note: the original
// pattern `[^a-zA-Z0-9_\\-]` (in a backtick string) escaped the backslash and
// therefore wrongly whitelisted a literal '\' character.
var safeStringRe = regexp.MustCompile(`[^a-zA-Z0-9_-]`)

// _is_safe_string reports whether s contains only letters, digits, '_' and
// '-'. The empty string is considered safe.
func _is_safe_string(s string) bool {
	return !safeStringRe.MatchString(s)
}
// GetPing measures the round-trip latency of one tracker node's /ping
// endpoint. It returns the elapsed time, the response (for metadata such as
// resp.Proto; its body is consumed and closed before returning), and an error
// when the request fails or the node does not answer 200.
func GetPing(ctx context.Context, pingClient *http.Client, node string) (*time.Duration, *http.Response, error) {
	start := time.Now()
	req, err := http.NewRequestWithContext(ctx, "GET", node+"ping", nil)
	if err != nil {
		return nil, nil, err
	}
	resp, err := pingClient.Do(req)
	if err != nil {
		return nil, nil, err
	}
	defer resp.Body.Close()
	// Drain the body so the underlying TCP/TLS connection can be reused.
	_, _ = io.Copy(io.Discard, resp.Body)
	elapsed := time.Since(start)
	if resp.StatusCode != 200 {
		// Callers already log the returned error; no extra Println here.
		return nil, resp, errors.New("status code not 200")
	}
	return &elapsed, resp, nil
}
// SelectBestTracker probes every TRACKER_NODES entry concurrently and adopts
// the first node that answers /ping successfully — the first responder is, by
// construction, the lowest-latency one. It panics when no node responds and
// no API_BASE was set before; otherwise a previously selected API_BASE is
// kept. Returns the receiver to allow chaining.
func (t *Tracker) SelectBestTracker() *Tracker {
	fmt.Println("[client->trackers] SelectBestTracker...")
	type result struct {
		node    string
		elapsed time.Duration
		err     error
	}
	// Buffered to len(TRACKER_NODES): late goroutines can always deliver
	// their result without blocking, so none of them leak after the early
	// break below.
	results := make(chan result, len(TRACKER_NODES))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	for _, node := range TRACKER_NODES {
		go func(ctx context.Context, node string) {
			elapsed, resp, err := GetPing(ctx, t.ping_client, node)
			if err != nil {
				if ctx.Err() == context.Canceled {
					// Probe was cancelled because a winner was already found.
					return
				}
				fmt.Println("ping error: ", err)
				results <- result{node: node, elapsed: time.Duration(math.MaxInt64), err: err}
			} else {
				results <- result{node: node, elapsed: *elapsed, err: nil}
				fmt.Println("node: ", node, "elapsed: ", elapsed, "protocol: ", resp.Proto)
			}
		}(ctx, node)
	}
	best_node := ""
	var best_elapsed time.Duration = time.Duration(math.MaxInt64)
	for range TRACKER_NODES {
		res := <-results
		if res.err == nil && res.elapsed < best_elapsed {
			best_elapsed = res.elapsed
			best_node = res.node
			break // the first usable node is the fastest one
		}
	}
	cancel()
	if best_node == "" && t.API_BASE == "" {
		panic("no tracker available")
	} else if best_node != "" {
		t.API_BASE = best_node
		fmt.Println("best_node: ", best_node, "best_elapsed: ", best_elapsed)
		fmt.Println("=====================================")
	}
	return t
}
// StartSelectTrackerBackground re-runs SelectBestTracker every 5 minutes in a
// goroutine; calling it a second time panics.
//
// NOTE(review): the __background_ping check-then-set is unsynchronized, so
// two concurrent calls could both pass the guard; the goroutine also has no
// stop mechanism. Confirm single-goroutine usage.
func (t *Tracker) StartSelectTrackerBackground() *Tracker {
	if t.__background_ping {
		panic("another StartSelectTrackerBackground is running")
	}
	t.__background_ping = true
	go func() {
		for {
			fmt.Println("[client->trackers] SelectBestTrackerBackground...")
			t.SelectBestTracker()
			time.Sleep(5 * time.Minute)
		}
	}()
	return t
}
// GetTracker builds a Tracker for project_id, acting as archivist with the
// given client_version. The result has no API_BASE yet — call
// SelectBestTracker (or StartSelectTrackerBackground) before use.
//
// It panics when project_id or archivist contains characters outside
// [a-zA-Z0-9_-], mirroring the assertion the Python client performs; the
// previous version silently accepted unsafe identifiers even though
// _is_safe_string existed for exactly this purpose.
func GetTracker(project_id string, client_version string, archivist string) *Tracker {
	if !_is_safe_string(project_id) || !_is_safe_string(archivist) {
		panic("project_id/archivist must match [a-zA-Z0-9_-]")
	}
	t := &Tracker{
		API_VERSION: "v1",
		ping_client: &http.Client{
			Timeout: 10 * time.Second, // latency probes should answer quickly
		},
		project_id: project_id,
		http_client: &http.Client{
			Timeout: 120 * time.Second, // generous limit for large insert_item payloads
		},
		client_version: client_version,
		archivist:      archivist,
		__gzPool: sync.Pool{
			// Reusable gzip writers for compressing insert_item bodies.
			New: func() interface{} {
				gz, err := gzip.NewWriterLevel(nil, gzip.BestCompression)
				if err != nil {
					panic(err)
				}
				return gz
			},
		},
	}
	return t
}

View File

@ -0,0 +1,421 @@
import asyncio
import copy
from dataclasses import dataclass
import gzip
import logging
import os
import re
import json
import time
from typing import Any, Dict, Optional, Union
import httpx
from httpx._content import encode_urlencoded_data
from .task import Task
FORCE_TRACKER_NODE = os.getenv("FORCE_TRACKER_NODE")
logger = logging.getLogger(__name__)
@dataclass
class ProjectMeta:
    # Display metadata for a project.
    identifier: str
    slug: str  # one-line project description
    icon: str  # icon URL
    deadline: str  # deadline, free-form text (no fixed format)


@dataclass
class ProjectStatus:
    public: bool  # hidden from /projects when False, but /project still answers
    paused: bool  # paused projects reject new claim_task; update_task/insert_item still work


@dataclass
class ProjectClient:
    version: str
    claim_task_delay: float
    """ How long a client must wait after claim_task before the next claim_task (soft limit). """


@dataclass
class ProjectMongodb:
    db_name: str
    item_collection: str
    queue_collection: str
    custom_doc_id_name: str
    """ TODO: deprecate in the future """


@dataclass
class Project:
    """Aggregate of everything the tracker reports about one project."""
    meta: ProjectMeta
    status: ProjectStatus
    client: ProjectClient
    mongodb: ProjectMongodb

    def __init__(self, meta: dict, status: dict, client: dict, mongodb: dict):
        # Manual __init__ (dataclass keeps it) so the tracker's raw JSON dicts
        # can be passed straight in and converted to the nested dataclasses.
        self.meta = ProjectMeta(**meta)
        self.status = ProjectStatus(**status)
        self.client = ProjectClient(**client)
        self.mongodb = ProjectMongodb(**mongodb)
# Candidate tracker endpoints probed by Tracker.select_best_tracker(); the
# localhost entry serves local development.
TRACKER_NODES = [
    "http://localhost:8080/",
    "https://0.tracker.saveweb.org/",
    "https://1.tracker.saveweb.org/",
    "https://2.tracker.saveweb.org/",
    "https://3.tracker.saveweb.org/",
]


class ClientVersionOutdatedError(Exception):
    """Raised when the tracker rejects this client's version."""
    pass
class Tracker:
    """Synchronous + asyncio client for the saveweb tracker HTTP API.

    A tracker node is selected by latency in __init__ (see
    select_best_tracker); project metadata is cached for 60 seconds (see the
    ``project`` property).
    """

    API_BASE: str = TRACKER_NODES[0]
    API_VERSION = "v1"
    client_version: str
    project_id: str
    """ [^a-zA-Z0-9_\\-] """
    archivist: str
    """ [^a-zA-Z0-9_\\-] """
    sync_session: httpx.Client
    session: httpx.AsyncClient
    # Serializes the claim_task delay across coroutines.
    # NOTE(review): defined at class level, so every Tracker instance shares
    # the same lock — confirm that is intended.
    _claim_wait_lock = asyncio.Lock()
    __project: Project|None = None  # cached project metadata
    __project_last_fetched: float = 0  # time.time() of the last fetch
    __last_task_claimed_at: float = 0  # time.time() of the last claim_task

    def _is_safe_string(self, s: str):
        # Accepts only [a-zA-Z0-9_\-]. NOTE(review): the doubled backslash in
        # the character class also whitelists a literal backslash — confirm
        # whether that is intended.
        r = re.compile(r"[^a-zA-Z0-9_\\-]")
        return not r.search(s)

    def __init__(self, project_id: str, client_version: str, archivist: str, session: httpx.AsyncClient|None = None):
        """
        raise: ClientVersionOutdatedError
        """
        assert self._is_safe_string(project_id) and self._is_safe_string(archivist), "[^a-zA-Z0-9_\\-]"
        self.project_id = project_id
        self.client_version = client_version
        self.archivist = archivist
        # Retrying transports; the sync client is always private, the async
        # client may be supplied by the caller.
        sync_transport = httpx.HTTPTransport(retries=3, http2=True, http1=True)
        transport = httpx.AsyncHTTPTransport(retries=3, http1=True, http2=True)
        self.sync_session = httpx.Client(http2=True, http1=True, timeout=120, transport=sync_transport)
        self.session = httpx.AsyncClient(http2=True, http1=True, timeout=120, transport=transport) if session is None else session
        self.select_best_tracker()
        if self.project.client.version != self.client_version:
            raise ClientVersionOutdatedError(f"client_version mismatch, please upgrade your client to {self.project.client.version}")
        print(f"[tracker] Hello, {self.archivist}!")
        print(f"[tracker] Project: {self.project}")

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Close both the sync and the async session.
        self.close()
        await self.aclose()

    def close(self):
        # Close the synchronous httpx client.
        self.sync_session.close()

    async def aclose(self):
        # Close the asynchronous httpx client.
        await self.session.aclose()

    @property
    def project(self):
        """
        return: Project.
        return __project cache if it's not expired (60s)
        """
        if self.__project is None or time.time() - self.__project_last_fetched > 60:
            try:
                self.__project = self.fetch_project(timeout=10)
            except Exception as e:
                print(f"[client->tracker] fetch_project failed. {e}")
                if self.__project is None:
                    # No stale copy to fall back on — give up.
                    raise e
            # Refresh the timestamp even after a failed fetch so a down
            # tracker is not hammered on every property access.
            self.__project_last_fetched = time.time()
        assert self.__project is not None
        return self.__project

    def select_best_tracker(self):
        """
        select best tracker node by latency
        if FORCE_TRACKER_NODE is set, use it directly
        """
        if FORCE_TRACKER_NODE:
            self.API_BASE = FORCE_TRACKER_NODE
            print(f"FORCE_TRACKER_NODE: {self.API_BASE}")
            return
        result = []  # [(node, latency_score_float)]
        print("[client->trackers] select_best_tracker...")
        # for node in TRACKER_NODES:
        #     try:
        #         self.sync_session.get(node + 'ping', timeout=0.05) # DNS preload, dirty hack
        #     except Exception:
        #         pass
        for node in TRACKER_NODES:
            start = time.time()
            print(f'--- {node} ---')
            try:
                print("[client->tracker] \tping tracker...")
                r = self.sync_session.get(node + 'ping', timeout=5)
                r.raise_for_status()
                ping_time = time.time() - start
                print(f"[client<-tracker] \tping tracker OK. {ping_time:.2f}s ({r.http_version})")
                print("[client->tracker->MongoDB]\tping MongoDB...")
                r = self.sync_session.get(node + 'ping_mongodb', timeout=8)
                print(r.text)  # {"NearestEalapsedTime":236,"PrimaryEalapsedTime":298,"message":"MongoDB is up"}.
                r.raise_for_status()  # 500 if MongoDB is down
                PrimaryEalapsedTime: int = r.json()["PrimaryEalapsedTime"]
                PrimaryEalapsedTimeSec = PrimaryEalapsedTime / 1000
                print(f"[client<-tracker<-MongoDB]\tping MongoDB OK. real latency: {ping_time+PrimaryEalapsedTimeSec:.2f}s")
                result.append((node, (ping_time*3)+(PrimaryEalapsedTimeSec*0.3)))  # DB latency is not very impactful
            except Exception as e:
                print(f"[client->tracker \tping failed. {e}")
                # Unreachable nodes sort last.
                result.append((node, float('inf')))
        result.sort(key=lambda x: x[1])
        self.API_BASE = result[0][0]
        print("===============================")
        print(f"tracker selected: {self.API_BASE}")
        print("===============================")

    def fetch_project(self, timeout: int = 10):
        """
        return: Project, deep copy
        also refresh __project cache
        """
        logger.info(f"[client->tracker] fetch_project... {self.project_id}")
        r = self.sync_session.post(self.API_BASE + self.API_VERSION + '/project/' + self.project_id, timeout=timeout)
        r.raise_for_status()
        proj = Project(**r.json())
        # Cache a deep copy so callers mutating the returned object cannot
        # corrupt the cache.
        self.__project = copy.deepcopy(proj)
        self.__project_last_fetched = time.time()
        logger.info(f"[client<-tracker] fetch_project. {self.project_id}")
        return proj

    def get_projects(self):
        # List every project the tracker advertises.
        logger.info(f"[client->tracker] get_projects.. {self.project_id}")
        r = self.sync_session.post(self.API_BASE + self.API_VERSION + '/projects')
        r.raise_for_status()
        logger.info(f"[client<-tracker] get_projects. {self.project_id}")
        return [Project(**p) for p in r.json()]

    def _before_claim_task(self, with_delay: bool = True):
        # Enforce the project's claim_task_delay QoS (soft rate limit) by
        # sleeping synchronously when we claimed too recently.
        if with_delay:
            if sleep_need := self.project.client.claim_task_delay - (time.time() - self.__last_task_claimed_at):
                if sleep_need > 0:
                    logger.info(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}")
                    time.sleep(sleep_need)
                elif sleep_need < 0:
                    logger.info(f"[tracker] you are {sleep_need:.2f}s late, Qos: {self.project.client.claim_task_delay}")

    async def _before_claim_task_async(self, with_delay: bool = True):
        # Async variant of _before_claim_task; the shared lock serializes the
        # delay across concurrent coroutines.
        if with_delay:
            if sleep_need := self.project.client.claim_task_delay - (time.time() - self.__last_task_claimed_at):
                if sleep_need > 0:
                    logger.info(f"[tracker] slow down you {sleep_need:.2f}s, Qos: {self.project.client.claim_task_delay}")
                    await self._claim_wait_lock.acquire()
                    await asyncio.sleep(sleep_need)
                    self._claim_wait_lock.release()
                elif sleep_need < 0:
                    logger.info(f"[tracker] you are {sleep_need:.2f}s late, Qos: {self.project.client.claim_task_delay}")

    def claim_task(self, with_delay: bool = True):
        # Synchronous claim; returns a Task or None (see _after_claim_task).
        self._before_claim_task(with_delay)
        logger.info("[client->tracker] claim_task")
        start = time.time()
        r = self.sync_session.post(f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/claim_task')
        return self._after_claim_task(r, start)

    async def claim_task_async(self, with_delay: bool = True):
        """
        raise: ClientVersionOutdatedError
        """
        await self._before_claim_task_async(with_delay)
        logger.info("[client->tracker] claim_task")
        start = time.time()
        r = await self.session.post(f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/claim_task')
        return self._after_claim_task(r, start)

    def _after_claim_task(self, r: httpx.Response, start: float):
        """
        raise: ClientVersionOutdatedError
        """
        self.__last_task_claimed_at = time.time()
        if r.status_code == 404:
            # NOTE(review): the extra positional arg is not %-merged into the
            # message; logging will report a formatting error at runtime.
            logger.info(f'[client<-tracker] claim_task. (time cost: {time.time() - start:.2f}s):', r.text)
            return None  # No tasks available
        if r.status_code == 200:
            r_json = r.json()
            logger.debug(f'[client<-tracker] claim_task. OK (time cost: {time.time() - start:.2f}s):', r_json)
            task = Task(**r_json)
            return task
        if r.status_code == 400:
            r_json = r.json()
            if 'error' in r_json and r_json['error'] == 'Client version not supported':
                raise ClientVersionOutdatedError(r.text)
        # Any other status (or an unrecognized 400) is unexpected.
        raise Exception(r.status_code, r.text)

    def _before_update_task(self, task_id: Union[str, int], status: str):
        # Shared pre-flight checks/logging for update_task(_async).
        assert isinstance(task_id, (str, int))
        assert isinstance(status, str)
        logger.info(f"[client->tracker] update_task task({task_id}) to status({status})")

    def update_task(self, task_id: Union[str, int], status: str):
        """
        task_id: pass an explicit int or str
        status: the new task status
        """
        self._before_update_task(task_id, status)
        r = self.sync_session.post(
            f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/update_task/{task_id}',
            data={
                'status': status,
                'task_id_type': 'int' if isinstance(task_id, int) else 'str'
            })
        return self._after_update_task(r, task_id)

    async def update_task_async(self, task_id: Union[str, int], status: str):
        """
        task_id: pass an explicit int or str
        status: the new task status
        raise: ClientVersionOutdatedError
        """
        self._before_update_task(task_id, status)
        r = await self.session.post(
            f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/update_task/{task_id}',
            data={
                'status': status,
                'task_id_type': 'int' if isinstance(task_id, int) else 'str'
            })
        return self._after_update_task(r, task_id)

    def _after_update_task(self, r: httpx.Response, task_id: Union[str, int])->Dict[str, Any]:
        if r.status_code == 200:
            r_json = r.json()
            # NOTE(review): extra positional arg is not %-merged into the
            # message; logging will report a formatting error at runtime.
            logger.info(f'[client<-tracker] update_task task({task_id}). OK:', r_json)
            return r_json
        if r.status_code == 400:
            r_json = r.json()
            if 'error' in r_json and r_json['error'] == 'Client version not supported':
                raise ClientVersionOutdatedError(r.text)
        raise Exception(r.text)

    def _before_insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]], payload: Optional[dict]):
        # Derive the wire-type markers the tracker needs to reconstruct the
        # original Python types server-side, and pre-serialize the payload.
        if isinstance(item_id, int):
            item_id_type = 'int'
        elif isinstance(item_id, str):
            item_id_type = 'str'
        else:
            raise ValueError("item_id must be int or str")
        if item_status is None:
            status_type = "None"
        elif isinstance(item_status, int):
            status_type = "int"
        elif isinstance(item_status, str):
            status_type = "str"
        else:
            raise ValueError("status must be int, str or None")
        payload_json_str = json.dumps(payload, ensure_ascii=False, separators=(',', ':'))
        logger.info(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}")
        return item_id_type, status_type, payload_json_str

    def insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None):
        """
        item_id: pass an explicit int or str
        item_status: the *item* status (not the task status); useful for
            marking deleted/hidden items without stuffing a pile of error
            responses into the payload
        payload: any object convertible to ext-json (including None)
        """
        item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload)
        r = self.sync_session.post(
            f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
            data={
                'item_id_type': item_id_type,
                'item_status': item_status,
                'item_status_type': status_type,
                'payload': payload_json_str
            })
        return self._after_insert_item(r, item_id)

    async def insert_item_async(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None):
        """
        item_id: pass an explicit int or str
        item_status: the *item* status (not the task status); useful for
            marking deleted/hidden items without stuffing a pile of error
            responses into the payload
        payload: any object convertible to ext-json (including None)
        """
        item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload)
        data = {
            'item_id_type': item_id_type,
            'item_status': item_status,
            'item_status_type': status_type,
            'payload': payload_json_str
        }
        # Pre-encode the form body so it can optionally be gzip-compressed.
        _, encoded_data_stream = encode_urlencoded_data(data)
        encoded_data = b''.join(encoded_data_stream)
        encoded_data_compressed = gzip.compress(encoded_data)
        rate = len(encoded_data_compressed) / len(encoded_data)
        if rate < 0.95:
            # low rate
            gzip_headers = { "Content-Encoding": "gzip", "Content-Type": "application/x-www-form-urlencoded"}
            logger.debug(f"compressed {len(encoded_data)} -> {len(encoded_data_compressed)}")
            # NOTE(review): debug leftover — writes the compressed body to the
            # working directory on every compressed upload.
            with open('debug.gzip', 'wb') as f:
                f.write(encoded_data_compressed)
            try:
                r = await self.session.post(
                    f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
                    content=encoded_data_compressed, headers=gzip_headers)
                return self._after_insert_item(r, item_id)
            except Exception as e:
                print(f"insert fail, retry (without gzip compression)... {e}")
        # high rate or retry without compression
        r = await self.session.post(
            f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
            data=data)
        return self._after_insert_item(r, item_id)
        raise RuntimeError("???? (^^)")  # NOTE(review): unreachable

    def _after_insert_item(self, r: httpx.Response, item_id: Union[str, int])->Dict[str, Any]:
        """
        return r_json if r.status_code == 200 else raise Exception(r.text)
        """
        if r.status_code == 200:
            r_json = r.json()
            # NOTE(review): extra positional arg is not %-merged into the
            # message; logging will report a formatting error at runtime.
            logger.info(f'[client<-tracker] insert_item item({item_id}). OK:', r_json)
            return r_json
        raise Exception(r.text)

View File

@ -0,0 +1,70 @@
package savewebtracker
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"time"
)
// Project returns the project metadata, serving the cached copy while it is
// fresher than 3 minutes and otherwise kicking off the background fetcher.
// On first use it blocks (polling every 5 seconds) until the cache has been
// populated; once a copy exists it is returned even when stale, letting the
// background refresh catch up. The previous implementation recursed into
// itself after waiting, which could recurse without bound (and eventually
// overflow the stack) whenever the cache stayed stale because the tracker
// was unreachable for more than 3 minutes.
//
// NOTE(review): __project/__project_last_fetched are read here and written by
// the fetch goroutine without synchronization — confirm single-goroutine
// usage or add locking.
func (t *Tracker) Project() (proj Project) {
	for {
		if t.__project != nil && time.Since(t.__project_last_fetched) < 3*time.Minute {
			return *t.__project
		}
		t.StartFetchProjectBackground()
		if t.__project != nil {
			// Stale but present: serve it instead of blocking.
			return *t.__project
		}
		fmt.Println("[client->tracker] waiting for project... ", t.project_id)
		time.Sleep(5 * time.Second)
	}
}
// StartFetchProjectBackground launches (at most once) a goroutine that
// refreshes the project cache every minute; subsequent calls are no-ops.
//
// NOTE(review): the __background_fetch_proj check-then-set is unsynchronized,
// and each FetchProject runs in its own goroutine, so fetches slower than a
// minute can pile up; the loop also has no stop mechanism.
func (t *Tracker) StartFetchProjectBackground() {
	if t.__background_fetch_proj {
		return
	}
	t.__background_fetch_proj = true
	go func() {
		for {
			go t.FetchProject(20 * time.Second)
			time.Sleep(1 * time.Minute)
		}
	}()
}
// FetchProject fetches the project document from the tracker with the given
// timeout and refreshes the __project cache on success.
//
// The previous version cancelled its context via time.AfterFunc(timeout, ...)
// and never deferred cancel, keeping the context's timer alive after every
// call; a plain defer cancel() releases it as soon as the function returns.
func (t *Tracker) FetchProject(timeout time.Duration) (proj *Project, err error) {
	fmt.Println("[client->tracker] fetch_project... ", t.project_id)
	ctx, cancel := context.WithTimeout(context.TODO(), timeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, "POST", t.API_BASE+t.API_VERSION+"/project/"+t.project_id, nil)
	if err != nil {
		log.Print(err)
		return nil, err
	}
	r, err := t.http_client.Do(req)
	if err != nil {
		log.Print(err)
		return nil, err
	}
	defer r.Body.Close()
	if r.StatusCode != 200 {
		return nil, errors.New("status code not 200")
	}
	proj = &Project{}
	err = json.NewDecoder(r.Body).Decode(proj)
	if err != nil {
		return nil, err
	}
	// NOTE(review): cache updated without synchronization — see Project().
	t.__project = proj
	t.__project_last_fetched = time.Now()
	fmt.Println("[client<-tracker] fetch_project. ", t.project_id)
	return proj, nil
}

View File

@ -0,0 +1,303 @@
package savewebtracker
import (
"bytes"
"compress/gzip"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// ObjectID is a MongoDB ObjectID rendered as a hex string.
type ObjectID string

// DatetimeUTC is an RFC 3339 UTC timestamp as sent by the tracker,
// e.g. "2024-06-06T15:45:00.008Z".
type DatetimeUTC string

// Validate reports whether s is one of the known Status values.
func (s Status) Validate() bool {
	for _, v := range LIST_STATUS {
		if v == s {
			return true
		}
	}
	return false
}

// GetDatetime parses the timestamp. It panics when the value does not parse
// as RFC 3339, so only call it on fields the tracker actually populated.
func (d DatetimeUTC) GetDatetime() time.Time {
	// example: 2024-06-06T15:45:00.008Z
	t, err := time.Parse(time.RFC3339, string(d))
	if err != nil {
		panic(err)
	}
	return t
}
// Task mirrors the tracker's task document. The id arrives as raw JSON
// because the server may send either a number or a string; _after_claim_task
// normalizes it into Id/Id_type.
type Task struct {
	Obj_id     ObjectID        `json:"_id"`
	Id         string          `json:"-"` // normalized string form of Id_raw
	Id_raw     json.RawMessage `json:"id"`
	Id_type    string          `json:"-"` // "str" or "int"
	Status     Status          `json:"status"`
	Archivist  string          `json:"archivist"`  // Optional
	Claimed_at DatetimeUTC     `json:"claimed_at"` // Optional
	Updated_at DatetimeUTC     `json:"updated_at"` // Optional
}

var (
	// ErrorClientVersionOutdated signals that the tracker rejected this
	// client_version. (Not returned anywhere in this file yet; kept for
	// callers.)
	ErrorClientVersionOutdated = errors.New("client version outdated")
	// ENABLE_GZIP toggles gzip compression of insert_item request bodies.
	ENABLE_GZIP = true
)
// ClaimTask asks the tracker for one task, returning nil when none is
// available and panicking on transport errors (matching this file's error
// style). When with_delay is true it first sleeps for the project's
// claim_task_delay, serialized across goroutines via _claim_wait_lock.
//
// Fixes over the previous version: the delay is converted through float64 so
// fractional delays (e.g. 0.5s) are no longer truncated to zero by an integer
// Duration conversion, and the response body is closed (it used to leak,
// since _after_claim_task only reads it).
func (t *Tracker) ClaimTask(with_delay bool) *Task {
	if with_delay {
		t._claim_wait_lock.Lock()
		delay := time.Duration(float64(t.Project().Client.ClaimTaskDelay) * float64(time.Second))
		time.Sleep(delay)
		t._claim_wait_lock.Unlock()
	}
	resp, err := t.http_client.Post(t.API_BASE+t.API_VERSION+"/project/"+t.project_id+"/"+t.client_version+"/"+t.archivist+"/claim_task", "", nil)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	return _after_claim_task(resp)
}
// _after_claim_task decodes a claim_task response. It returns nil when the
// tracker has no task available (HTTP 404), a decoded *Task on HTTP 200, and
// panics with the response body on any other status. The task id may arrive
// as a JSON number or a JSON string; it is normalized into Id, with Id_type
// set to "int" or "str" accordingly.
func _after_claim_task(r *http.Response) *Task {
	if r.StatusCode == 404 {
		return nil
	}
	if r.StatusCode == 200 {
		task := Task{}
		err := json.NewDecoder(r.Body).Decode(&task)
		if err != nil {
			panic(err)
		}
		var idInt int
		var idString string
		if err := json.Unmarshal(task.Id_raw, &idInt); err == nil {
			// Numeric id: keep a string rendering for URL building.
			idString = fmt.Sprintf("%d", idInt)
			task.Id = idString
			task.Id_type = "int"
		} else if err := json.Unmarshal(task.Id_raw, &idString); err == nil {
			task.Id = idString
			task.Id_type = "str"
		} else {
			panic(err)
		}
		return &task
	}
	BodyBytes, _ := io.ReadAll(r.Body)
	panic(string(BodyBytes))
}
// UpdateTask reports a task's new status to the tracker and returns the raw
// response body. It panics when to_status is not a known Status value or
// when the request cannot be sent.
func (t *Tracker) UpdateTask(task_id string, id_type string, to_status Status) string {
	if !to_status.Validate() {
		fmt.Println("invalid status, to_status:", to_status)
		panic("invalid status")
	}
	form := url.Values{}
	form.Set("status", string(to_status))
	form.Set("task_id_type", id_type)
	endpoint := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/update_task/" + task_id
	resp, err := t.http_client.Post(endpoint, "application/x-www-form-urlencoded", strings.NewReader(form.Encode()))
	if err != nil {
		panic(err)
	}
	return _after_update_task(resp)
}
// _after_update_task consumes an update_task response, returning the body
// text on HTTP 200 and panicking (with the body) on any other status. A 400
// gets no extra logging; other statuses log the status code and URL first.
func _after_update_task(r *http.Response) string {
	// Close the body — callers hand over ownership; the previous version
	// leaked the connection by never closing it.
	defer r.Body.Close()
	bodyBytes, err := io.ReadAll(r.Body)
	if err != nil {
		panic(err)
	}
	text := string(bodyBytes)
	if r.StatusCode == 200 {
		return text
	}
	if r.StatusCode == 400 {
		panic(text)
	}
	fmt.Println(r.StatusCode, r.Request.URL, text)
	panic(text)
}
// def _before_insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]], payload: Optional[dict]):
// # item_id_type
// if isinstance(item_id, int):
// item_id_type = 'int'
// elif isinstance(item_id, str):
// item_id_type = 'str'
// else:
// raise ValueError("item_id must be int or str")
// if item_status is None:
// status_type = "None"
// elif isinstance(item_status, int):
// status_type = "int"
// elif isinstance(item_status, str):
// status_type = "str"
// else:
// raise ValueError("status must be int, str or None")
// payload_json_str = json.dumps(payload, ensure_ascii=False, separators=(',', ':'))
// # print(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}")
// logger.info(f"[client->tracker] insert_item item({item_id}), len(payload)={len(payload_json_str)}")
// return item_id_type, status_type, payload_json_str
// InsertItem stores an item result for task. status_type describes
// item_status and must be "int", "str" or "None"; payload is a
// pre-serialized JSON string. When ENABLE_GZIP is set and compression
// shrinks the form body below 95% of its original size, the body is sent
// gzip-compressed. Returns the raw response body; panics on any error.
func (t *Tracker) InsertItem(task Task, item_status string, status_type string, payload string) string {
	if status_type != "int" && status_type != "str" && status_type != "None" {
		panic("status must be int, str or None")
	}
	req_url := t.API_BASE + t.API_VERSION + "/project/" + t.project_id + "/" + t.client_version + "/" + t.archivist + "/insert_item/" + task.Id
	data := url.Values{}
	data.Set("item_id_type", task.Id_type)
	data.Set("item_status", item_status)
	data.Set("item_status_type", status_type)
	data.Set("payload", payload)
	encodedData := data.Encode()
	gzBuf := &bytes.Buffer{}
	gz := t.__gzPool.Get().(*gzip.Writer)
	// Deferred cleanup runs LIFO: Close the writer (a no-op when already
	// closed by the explicit Close below), detach it from gzBuf via
	// Reset(io.Discard), then return it to the pool.
	defer t.__gzPool.Put(gz)
	defer gz.Reset(io.Discard)
	defer gz.Close()
	gz.Reset(gzBuf)
	if _, err := gz.Write([]byte(encodedData)); err != nil {
		panic(err)
	}
	// NOTE(review): Flush immediately before Close is redundant — Close
	// flushes pending data itself.
	if err := gz.Flush(); err != nil {
		panic(err)
	}
	gz.Close()
	len_encodedData := len([]byte(encodedData))
	// fmt.Printf("compressed %d -> %d \n", len_encodedData, gzBuf.Len())
	req := &http.Request{} // NOTE(review): dead value; always overwritten below
	var err error
	if ENABLE_GZIP && float64(gzBuf.Len())/float64(len_encodedData) < 0.95 { // good compression rate
		fmt.Println("using gzip...", gzBuf.Len())
		req, err = http.NewRequest("POST", req_url, gzBuf)
		if err != nil {
			panic(err)
		}
		req.Header.Set("Content-Encoding", "gzip")
	} else {
		req, err = http.NewRequest("POST", req_url, strings.NewReader(encodedData))
		if err != nil {
			panic(err)
		}
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "*/*")
	fmt.Println(req.Header) // NOTE(review): debug leftover; consider removing
	resp, err := t.http_client.Do(req)
	if err != nil {
		panic(err)
	}
	return _after_insert_item(resp)
}
// _after_insert_item consumes an insert_item response: it returns the body
// text on HTTP 200 and panics (after logging the status code and URL) on
// anything else.
func _after_insert_item(r *http.Response) string {
	defer r.Body.Close()
	raw, err := io.ReadAll(r.Body)
	if err != nil {
		panic(err)
	}
	body := string(raw)
	if r.StatusCode != 200 {
		fmt.Println(r.StatusCode, r.Request.URL, body)
		panic(body)
	}
	return body
}
// def insert_item(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None):
// """
// item_id: 必须明确传入的是 int 还是 str
// item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 之类的。就不用添加一堆错误响应到 payload 里了。
// payload: 可以传入任意的可转为 ext-json 的对象 (包括 None)
// """
// item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload)
// r = self.sync_session.post(
// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
// data={
// 'item_id_type': item_id_type,
// 'item_status': item_status,
// 'item_status_type': status_type,
// 'payload': payload_json_str
// })
// return self._after_insert_item(r, item_id)
// async def insert_item_async(self, item_id: Union[str, int], item_status: Optional[Union[str, int]] = None, payload: Optional[dict] = None):
// """
// item_id: 必须明确传入的是 int 还是 str
// item_status: item 状态,而不是 task 状态。可用于标记一些被删除、被隐藏的 item 之类的。就不用添加一堆错误响应到 payload 里了。
// payload: 可以传入任意的可转为 ext-json 的对象 (包括 None)
// """
// item_id_type, status_type, payload_json_str = self._before_insert_item(item_id, item_status, payload)
// data = {
// 'item_id_type': item_id_type,
// 'item_status': item_status,
// 'item_status_type': status_type,
// 'payload': payload_json_str
// }
// _, encoded_data_stream = encode_urlencoded_data(data)
// encoded_data = b''.join(encoded_data_stream)
// encoded_data_compressed = gzip.compress(encoded_data)
// rate = len(encoded_data_compressed) / len(encoded_data)
// if rate < 0.95:
// # low rate
// gzip_headers = { "Content-Encoding": "gzip", "Content-Type": "application/x-www-form-urlencoded"}
// logger.debug(f"compressed {len(encoded_data)} -> {len(encoded_data_compressed)}")
// try:
// r = await self.session.post(
// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
// content=encoded_data_compressed, headers=gzip_headers)
// return self._after_insert_item(r, item_id)
// except Exception as e:
// print(f"insert fail, retry (without gzip compression)... {e}")
// # high rate or retry without compression
// r = await self.session.post(
// f'{self.API_BASE}{self.API_VERSION}/project/{self.project_id}/{self.client_version}/{self.archivist}/insert_item/{item_id}',
// data=data)
// return self._after_insert_item(r, item_id)
// raise RuntimeError("???? (^^)")
// def _after_insert_item(self, r: httpx.Response, item_id: Union[str, int])->Dict[str, Any]:
// """
// return r_json if r.status_code == 200 else raise Exception(r.text)
// """
// if r.status_code == 200:
// r_json = r.json()
// # print(f'[client<-tracker] insert_item item({item_id}). OK:', r_json)
// logger.info(f'[client<-tracker] insert_item item({item_id}). OK:', r_json)
// return r_json
// raise Exception(r.text)

View File

@ -0,0 +1,84 @@
package savewebtracker
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// TestCliamTask claims a task from the live tracker (network required) and
// prints its decoded fields. NOTE(review): "Cliam" is a typo for "Claim";
// renaming would be safe since test names are discovery-only.
func TestCliamTask(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	task := tracker.ClaimTask(false)
	if task == nil {
		t.Error("task is nil")
	}
	fmt.Println(task)
	json_str, _ := json.Marshal(&task)
	fmt.Println(string(json_str))
	// Guard the dereferences: t.Error above does not stop the test.
	if task != nil {
		fmt.Println(
			"Id_type", task.Id_type, '\n',
			"Updated_at", task.Updated_at.GetDatetime(), '\n',
			"Claimed_at", task.Claimed_at.GetDatetime(), '\n',
		)
	}
}
// TestCliamAndUpdateTask claims a task and immediately reports it as DONE
// (network required). NOTE(review): "Cliam" is a typo for "Claim".
func TestCliamAndUpdateTask(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	task := tracker.ClaimTask(false)
	if task == nil {
		t.Error("task is nil")
		return
	}
	text := tracker.UpdateTask(task.Id, task.Id_type, StatusDONE)
	fmt.Println("task_id:", task.Id)
	fmt.Println(text)
}
// TestBytes checks that a multi-byte UTF-8 string survives a
// string -> []byte -> string round trip unchanged.
func TestBytes(t *testing.T) {
	original := "京哈随章长而窄喝着和"
	roundTripped := string([]byte(original))
	assert.Equal(t, original, roundTripped)
}
// TestClaimAndUpdateAndInsertTask runs the full flow against the live
// tracker: claim a task, mark it DONE, then insert an item carrying a nested
// JSON payload for the same task (network required).
func TestClaimAndUpdateAndInsertTask(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	task := tracker.ClaimTask(false)
	if task == nil {
		t.Error("task is nil")
		return
	}
	text := tracker.UpdateTask(task.Id, task.Id_type, StatusDONE)
	fmt.Println("task_id:", task.Id, text)
	payload := map[string]interface{}{
		"kjasd": "111111111122222111111111122222111111111122222111111111122222",
		"abasr": "111111111122222111111111122222111111111122222111111111122222",
		"list":  []int{1, 2, 3, 4, 5, 6, 7, 8, 9},
		"dict":  map[string]interface{}{"a": 1, "b": 2, "c": 3},
		"nested": map[string]interface{}{
			"list": []int{1, 2, 3, 4, 5, 6, 7, 8, 9},
			"nested": map[string]interface{}{
				"list": []int{1, 2, 3, 4, 5, 6, 7, 8, 9},
			},
		},
	}
	payload_json_str, _ := json.Marshal(payload)
	// status_type "None" mirrors the Python client's item_status=None.
	text = tracker.InsertItem(*task, "null", "None", string(payload_json_str))
	fmt.Println("inserted task_id:", task.Id, text)
}
// TestClaimWithOldVersion forces an unsupported client version and expects
// the claim to panic (the tracker rejects mismatched versions).
func TestClaimWithOldVersion(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	tracker.client_version = "0.0"
	assert.Panics(t, func() {
		task := tracker.ClaimTask(false)
		if task == nil {
			t.Error("task is nil")
		}
	})
}

View File

@ -0,0 +1,36 @@
package savewebtracker
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// GetDefaultTracker builds the Tracker used by these tests: project "test",
// client version "1.0", archivist "archivist-test".
func GetDefaultTracker() *Tracker {
	return GetTracker("test", "1.0", "archivist-test")
}
// TestIsSafeString covers accepted identifiers (letters, digits, '-', '_',
// and the empty string) and rejected ones ('!' and path characters).
func TestIsSafeString(t *testing.T) {
	assert.True(t, _is_safe_string("abc"))
	assert.True(t, _is_safe_string("abc-gg36c3rc-5vw_"))
	assert.True(t, _is_safe_string(""))
	assert.False(t, _is_safe_string("abc-5vw_!"))
	assert.False(t, _is_safe_string("asdasd/.."))
}
// TestGetTracker checks that node selection populates API_BASE
// (network required).
func TestGetTracker(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	assert.NotEmpty(t, tracker.API_BASE)
}
// TestGetProject verifies that the first Project() call kicks off the
// background fetcher and that a second call returns the same cached value
// (network required).
func TestGetProject(t *testing.T) {
	tracker := GetDefaultTracker().SelectBestTracker()
	assert.False(t, tracker.__background_fetch_proj)
	proj := tracker.Project()
	assert.True(t, tracker.__background_fetch_proj)
	assert.NotNil(t, proj)
	fmt.Println(proj)
	same_proj := tracker.Project()
	assert.Equal(t, proj, same_proj)
}

18
tests/tracker_test.py Normal file
View File

@ -0,0 +1,18 @@
import pytest
import saveweb_tracker
import saveweb_tracker.tracker
@pytest.mark.asyncio
async def test_tracker():
    # End-to-end smoke test against the live tracker: claim a task, mark it
    # DONE, then insert an item with a small payload. Requires network access
    # and an available task in the "test" project.
    async with saveweb_tracker.tracker.Tracker("test", "1.0", "test-python") as tracker:
        task = await tracker.claim_task_async(with_delay=False)
        print(task)
        assert task is not None
        text = await tracker.update_task_async(task_id=task.id, status="DONE")
        print(text)
        text = await tracker.insert_item_async(
            item_id=task.id, item_status=None, payload={
                "kjasd": "111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222111111111122222",
            }
        )
        print(text)