feat: add archivist util
This commit is contained in:
parent
67c2c0dfbc
commit
f9ec76dcc8
70
src/saveweb_tracker/archivist.go
Normal file
70
src/saveweb_tracker/archivist.go
Normal file
@ -0,0 +1,70 @@
|
||||
package savewebtracker
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func _newArchivist() string {
|
||||
fmt.Println("zh: 第一次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。(合法字符:字母、数字、-、_)")
|
||||
fmt.Println("en: First run, please input a string that can uniquely identify your node (for example: bob-gcloud-514). (Legal characters: letters, numbers, -, _)")
|
||||
fmt.Print("ARCHIVIST: ")
|
||||
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if scanner.Scan() {
|
||||
arch := scanner.Text()
|
||||
file, err := os.Create("ARCHIVIST.conf")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = file.WriteString(arch + "\n")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return _getArchivist()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
panic("no input")
|
||||
}
|
||||
|
||||
func Archivist() string {
|
||||
arch := _getArchivist()
|
||||
if arch == "" {
|
||||
arch = _newArchivist()
|
||||
}
|
||||
if !_is_safe_string(arch) {
|
||||
panic("ARCHIVIST contains illegal characters")
|
||||
}
|
||||
|
||||
return arch
|
||||
|
||||
}
|
||||
|
||||
func _getArchivist() string {
|
||||
arch, exists := os.LookupEnv("ARCHIVIST")
|
||||
if exists {
|
||||
return arch
|
||||
}
|
||||
|
||||
if _, err := os.Stat("ARCHIVIST.conf"); err == nil {
|
||||
file, err := os.Open("ARCHIVIST.conf")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
if scanner.Scan() {
|
||||
return scanner.Text()
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
61
src/saveweb_tracker/archivist_test.go
Normal file
61
src/saveweb_tracker/archivist_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package savewebtracker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewArchivist(t *testing.T) {
|
||||
|
||||
// let's sent "test" to stdin
|
||||
tmpFile, err := os.CreateTemp("", "stdin")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
_, err = tmpFile.WriteString("test\nasdsadasdsd\n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tmpFile.Seek(0, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
os.Stdin = tmpFile
|
||||
|
||||
arch := _newArchivist()
|
||||
fmt.Println(arch)
|
||||
assert.Equal(t, "test", arch)
|
||||
}
|
||||
|
||||
func TestGetArchivist(t *testing.T) {
|
||||
// let's set an environment variable
|
||||
os.Setenv("ARCHIVIST", "test")
|
||||
|
||||
arch := Archivist()
|
||||
assert.Equal(t, "test", arch)
|
||||
}
|
||||
|
||||
func TestGetIllegalArchivist(t *testing.T) {
|
||||
// let's set an environment variable
|
||||
os.Setenv("ARCHIVIST", "test../")
|
||||
|
||||
assert.Panics(t, func() {
|
||||
Archivist()
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetEmptyArchivist(t *testing.T) {
|
||||
// let's set an environment variable
|
||||
os.Setenv("ARCHIVIST", "")
|
||||
|
||||
assert.Panics(t, func() {
|
||||
Archivist()
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user