From f9ec76dcc8fd95127898db981b131e2e04627058 Mon Sep 17 00:00:00 2001 From: yzqzss Date: Fri, 7 Jun 2024 17:56:23 +0800 Subject: [PATCH] feat: add archivist util --- src/saveweb_tracker/archivist.go | 70 +++++++++++++++++++++++++++ src/saveweb_tracker/archivist_test.go | 61 +++++++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 src/saveweb_tracker/archivist.go create mode 100644 src/saveweb_tracker/archivist_test.go diff --git a/src/saveweb_tracker/archivist.go b/src/saveweb_tracker/archivist.go new file mode 100644 index 0000000..68ed740 --- /dev/null +++ b/src/saveweb_tracker/archivist.go @@ -0,0 +1,70 @@ +package savewebtracker + +import ( + "bufio" + "fmt" + "os" +) + +func _newArchivist() string { + fmt.Println("zh: 第一次运行,请输入可以唯一标识您节点的字符串,例如 alice-aws-114。(合法字符:字母、数字、-、_)") + fmt.Println("en: First run, please input a string that can uniquely identify your node (for example: bob-gcloud-514). (Legal characters: letters, numbers, -, _)") + fmt.Print("ARCHIVIST: ") + + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + arch := scanner.Text() + file, err := os.Create("ARCHIVIST.conf") + if err != nil { + panic(err) + } + defer file.Close() + + _, err = file.WriteString(arch + "\n") + if err != nil { + panic(err) + } + return _getArchivist() + } + if err := scanner.Err(); err != nil { + panic(err) + } + panic("no input") +} + +func Archivist() string { + arch := _getArchivist() + if arch == "" { + arch = _newArchivist() + } + if !_is_safe_string(arch) { + panic("ARCHIVIST contains illegal characters") + } + + return arch + +} + +func _getArchivist() string { + arch, exists := os.LookupEnv("ARCHIVIST") + if exists { + return arch + } + + if _, err := os.Stat("ARCHIVIST.conf"); err == nil { + file, err := os.Open("ARCHIVIST.conf") + if err != nil { + panic(err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + if scanner.Scan() { + return scanner.Text() + } + if err := scanner.Err(); err != nil { + panic(err) + } + } + return "" +} diff --git a/src/saveweb_tracker/archivist_test.go b/src/saveweb_tracker/archivist_test.go new file mode 100644 index 0000000..4d4aa6e --- /dev/null +++ b/src/saveweb_tracker/archivist_test.go @@ -0,0 +1,61 @@ +package savewebtracker + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewArchivist(t *testing.T) { + + // let's sent "test" to stdin + tmpFile, err := os.CreateTemp("", "stdin") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + _, err = tmpFile.WriteString("test\nasdsadasdsd\n") + if err != nil { + t.Fatal(err) + } + + _, err = tmpFile.Seek(0, 0) + if err != nil { + t.Fatal(err) + } + + os.Stdin = tmpFile + + arch := _newArchivist() + fmt.Println(arch) + assert.Equal(t, "test", arch) +} + +func TestGetArchivist(t *testing.T) { + // let's set an environment variable + os.Setenv("ARCHIVIST", "test") + + arch := Archivist() + assert.Equal(t, "test", arch) +} + +func TestGetIllegalArchivist(t *testing.T) { + // let's set an environment variable + os.Setenv("ARCHIVIST", "test../") + + assert.Panics(t, func() { + Archivist() + }) +} + +func TestGetEmptyArchivist(t *testing.T) { + // let's set an environment variable + os.Setenv("ARCHIVIST", "") + + assert.Panics(t, func() { + Archivist() + }) +}