pretty code

2021年3月4日 星期四

Golang data race with fmt package

Goroutine 是 Go 語言裡面一個很重要的功能,它是一個輕量化的執行緒,由 Go runtime 負責排程執行,據官方說法,同時開啟好幾千個也沒問題。

使用 Golang 開發專案時或多或少都會使用到 Goroutine,我們也很習慣在 Goroutine 裡面直接使用 fmt package,但 fmt 並不保證 concurrency safe。

最近我在專案中使用了 GIN 與 logrotate,發現在某些情況下會導致 data race,底下便是我精簡過的程式碼。


package main

import (
    "bufio"
    "errors"
    "path/filepath"
    "fmt"
    "log"
    "os"
    "sync"
    "time"
)

// Options configures a Writer.
type Options struct {
    // Directory is the directory in which log files are created. New creates
    // it when it does not exist.
    Directory string
    // FileNameFunc returns the name of the next log file; rotate joins it onto
    // Directory each time a new file is opened.
    FileNameFunc func() string
}

// Writer is an asynchronous file writer: Write enqueues buffers on a channel
// and a single background goroutine (listen) writes them to the current log
// file through a bufio.Writer.
type Writer struct {
    // logger receives error messages from the background goroutine.
    logger *log.Logger

    // opts holds the directory and file-name settings passed to New.
    opts Options

    // f is the currently open log file; nil until the first rotate.
    f *os.File
    // bw buffers writes to f; it is replaced whenever rotate opens a new file.
    bw *bufio.Writer
    // bytesWritten is the running byte count for the current file; it is
    // reset to 0 by rotate and closeCurrentFile.
    bytesWritten int64

    // queue carries buffers from Write to listen.
    queue chan []byte
    // pending counts Write calls that are currently in progress so that Close
    // can wait for them before closing queue.
    pending sync.WaitGroup
    // closing is closed by Close to make subsequent Write calls fail.
    closing chan struct{}
    // done is closed by listen once queue has been closed and drained.
    done chan struct{}
}

// Write hands p to the background listen goroutine through w.queue. It
// returns an error once Close has closed w.closing; otherwise it reports
// len(p) bytes as written as soon as p has been enqueued, before the data has
// reached the file.
//
// NOTE(review): p itself is enqueued, not a copy, so listen may still be
// reading it after Write has returned and the caller has reused the buffer.
// The io.Writer contract states that implementations must not retain p. The
// two commented-out lines below are the fix; they assume the parameter is
// renamed to b.
func (w *Writer) Write(p []byte) (n int, err error) {
    //p := make([]byte, len(b))
    //copy(p, b)

    select {
    case <-w.closing:
        return 0, errors.New("writer is closing")
    default:
        // Register this call so that Close waits for it before closing w.queue.
        w.pending.Add(1)
        defer w.pending.Done()
    }

    // Blocks while the queue's buffer is full.
    w.queue <- p

    return len(p), nil
}

// Close shuts the writer down. It closes w.closing so that later Write calls
// are rejected, waits for Write calls already in progress, then closes
// w.queue and blocks until listen reports through w.done that the queue has
// been drained. If a log file is open at that point it is closed via
// closeCurrentFile and any error from that is returned.
func (w *Writer) Close() error {
    // Refuse new writes, then let the in-flight ones finish enqueuing.
    close(w.closing)
    w.pending.Wait()

    // Nothing sends on the queue any more; close it and wait for the drain.
    close(w.queue)
    <-w.done

    // No file was ever opened, so there is nothing to flush or close.
    if w.f == nil {
        return nil
    }
    return w.closeCurrentFile()
}

// listen is the single consumer of w.queue and runs in its own goroutine
// (started by New). For every queued buffer it makes sure a log file is open,
// writes the buffer through w.bw and updates w.bytesWritten. It returns, after
// closing w.done, once w.queue has been closed and drained.
//
// Failures are reported through w.logger because there is no caller to return
// them to.
func (w *Writer) listen() {
    for b := range w.queue {
        if w.f == nil {
            if err := w.rotate(); err != nil {
                w.logger.Println("Failed to create log file", err)
                // Without a file there may be no buffered writer either (w.bw
                // is still nil if no file was ever opened), so writing would
                // dereference nil. Drop this buffer; the next one retries the
                // rotation because w.f is still nil.
                continue
            }
        }

        n, err := w.bw.Write(b)
        if err != nil {
            w.logger.Println("Failed to write to file.", err)
        }
        // Count only what the buffered writer actually accepted.
        w.bytesWritten += int64(n)
    }

    close(w.done)
}

// closeCurrentFile flushes w.bw, syncs w.f to disk and closes it, in that
// order, stopping at the first failure. On success it resets w.bytesWritten
// to 0. Callers must ensure w.f and w.bw are non-nil.
//
// Each returned error wraps the underlying cause, so callers can inspect it
// with errors.Is / errors.As and the log shows why the step failed.
func (w *Writer) closeCurrentFile() error {
    if err := w.bw.Flush(); err != nil {
        return fmt.Errorf("failed to flush buffered writer: %w", err)
    }

    if err := w.f.Sync(); err != nil {
        return fmt.Errorf("failed to sync current log file: %w", err)
    }

    if err := w.f.Close(); err != nil {
        return fmt.Errorf("failed to close current log file: %w", err)
    }

    w.bytesWritten = 0
    return nil
}

// rotate closes the current log file, if any, and opens a new one whose name
// comes from w.opts.FileNameFunc, joined onto w.opts.Directory. On success
// w.f and w.bw point at the new file and w.bytesWritten is reset to 0; on
// failure they are left as they were.
//
// The error from opening the file is wrapped rather than replaced, so the
// path and OS-level reason are preserved.
func (w *Writer) rotate() error {
    if w.f != nil {
        if err := w.closeCurrentFile(); err != nil {
            return err
        }
    }

    path := filepath.Join(w.opts.Directory, w.opts.FileNameFunc())
    f, err := newFile(path)
    if err != nil {
        return fmt.Errorf("failed to create new file: %w", err)
    }

    w.bw = bufio.NewWriter(f)
    w.f = f
    w.bytesWritten = 0

    return nil
}

// New creates a Writer that logs into opts.Directory and starts its
// background listen goroutine. The directory is created (with parents) if it
// does not exist yet. Call Close to flush buffered data and stop the
// goroutine.
//
// logger receives errors raised by the background goroutine. It returns an
// error if opts.FileNameFunc is nil or if the directory cannot be prepared.
func New(logger *log.Logger, opts Options) (*Writer, error) {
    // A nil FileNameFunc would otherwise panic later, inside the background
    // goroutine, on the first write.
    if opts.FileNameFunc == nil {
        return nil, errors.New("FileNameFunc must not be nil")
    }

    // MkdirAll is a no-op for an existing directory and fails when the path
    // exists but is not a directory, so no separate os.Stat is needed and no
    // Stat error is silently ignored.
    if err := os.MkdirAll(opts.Directory, os.ModePerm); err != nil {
        return nil, fmt.Errorf("could not prepare log directory: %w", err)
    }

    w := &Writer{
        logger:  logger,
        opts:    opts,
        queue:   make(chan []byte, 2000),
        closing: make(chan struct{}),
        done:    make(chan struct{}),
    }

    go w.listen()

    return w, nil
}

// newFile opens path write-only, creating the file if it is missing and
// truncating it if it already exists. New files are created with mode 0666
// (before umask).
func newFile(path string) (*os.File, error) {
    const (
        flags             = os.O_CREATE | os.O_WRONLY | os.O_TRUNC
        perm  os.FileMode = 0666
    )
    return os.OpenFile(path, flags, perm)
}

// test is the body of one worker goroutine: it prints a fixed string to
// standard output, writes the same string to w through fmt.Fprint, and then
// marks itself finished on wg.
func test(wg *sync.WaitGroup, w *Writer) {
    const msg = "abcdefg"

    // A single iteration; the loop bound is the knob for repeating the writes.
    for n := 0; n < 1; n++ {
        fmt.Println(msg)
        fmt.Fprint(w, msg)
    }

    wg.Done()
}

// FileNameFunc builds a log file name from the current local time in the form
// YYYYMMDD_hhmmss.log.
func FileNameFunc() string {
    now := time.Now()
    year, month, day := now.Date()
    hour, minute, second := now.Clock()

    return fmt.Sprintf("%04d%02d%02d_%02d%02d%02d.log",
        year, month, day, hour, minute, second)
}

// main starts a Writer that logs into the current working directory, launches
// a fixed number of goroutines that each write to it concurrently, waits for
// them, and then closes the writer so that buffered data reaches the file.
func main() {
    var wg sync.WaitGroup

    logger := log.New(os.Stderr, "logrotate", 0)

    dir, err := os.Getwd()
    if err != nil {
        fmt.Println(err)
        return
    }

    opts := Options{
        Directory:    dir,
        FileNameFunc: FileNameFunc,
    }

    writer, err := New(logger, opts)
    if err != nil {
        logger.Println(err)
        return
    }

    const workers = 2000
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go test(&wg, writer)
    }

    wg.Wait()

    // Without Close the data still sitting in the queue and in the buffered
    // writer is lost when the process exits, and the background goroutine is
    // never stopped.
    if err := writer.Close(); err != nil {
        logger.Println(err)
    }
}

這是測試的結果。

==================
WARNING: DATA RACE
Read at 0x00c00000a523 by goroutine 7:
  runtime.slicecopy()
      c:/go/src/runtime/slice.go:197 +0x0
  bufio.(*Writer).Write()
      c:/go/src/bufio/bufio.go:635 +0x3e5
  main.(*Writer).listen()
      E:/test.go:77 +0x1ab

Previous write at 0x00c00000a523 by goroutine 185:
  runtime.slicestringcopy()
      c:/go/src/runtime/slice.go:232 +0x0
  fmt.(*buffer).writeString()
      c:/go/src/fmt/print.go:82 +0x107
  fmt.(*fmt).padString()
      c:/go/src/fmt/format.go:110 +0x6f
  fmt.(*fmt).fmtS()
      c:/go/src/fmt/format.go:359 +0x75
  fmt.(*pp).fmtString()
      c:/go/src/fmt/print.go:447 +0x1ba
  fmt.(*pp).printArg()
      c:/go/src/fmt/print.go:698 +0xdcf
  fmt.(*pp).doPrint()
      c:/go/src/fmt/print.go:1161 +0x12c
  fmt.Fprint()
      c:/go/src/fmt/print.go:232 +0x6c
  main.test()
      E:/test.go:154 +0x130

Goroutine 7 (running) created at:
  main.New()
      E:/test.go:138 +0x29b
  main.main()
      E:/test.go:190 +0x271

Goroutine 185 (finished) created at:
  main.main()
      E:/test.go:200 +0x2e0
==================
==================
WARNING: DATA RACE
Read at 0x00c00008b1bd by goroutine 7:
  runtime.slicecopy()
      c:/go/src/runtime/slice.go:197 +0x0
  bufio.(*Writer).Write()
      c:/go/src/bufio/bufio.go:635 +0x3e5
  main.(*Writer).listen()
      E:/test.go:77 +0x1ab

Previous write at 0x00c00008b1bd by goroutine 702:
  runtime.slicestringcopy()
      c:/go/src/runtime/slice.go:232 +0x0
  fmt.(*buffer).writeString()
      c:/go/src/fmt/print.go:82 +0x107
  fmt.(*fmt).padString()
      c:/go/src/fmt/format.go:110 +0x6f
  fmt.(*fmt).fmtS()
      c:/go/src/fmt/format.go:359 +0x75
  fmt.(*pp).fmtString()
      c:/go/src/fmt/print.go:447 +0x1ba
  fmt.(*pp).printArg()
      c:/go/src/fmt/print.go:698 +0xdcf
  fmt.(*pp).doPrintln()
      c:/go/src/fmt/print.go:1173 +0xb4
  fmt.Fprintln()
      c:/go/src/fmt/print.go:264 +0x6c
  fmt.Println()
      c:/go/src/fmt/print.go:274 +0xc0
  main.test()
      E:/test.go:153 +0x42

Goroutine 7 (running) created at:
  main.New()
      E:/test.go:138 +0x29b
  main.main()
      E:/test.go:190 +0x271

Goroutine 702 (running) created at:
  main.main()
      E:/test.go:200 +0x2e0
==================
Found 2 data race(s)
exit status 66


從上面我們可以看到,不管是 fmt.Println 或是 fmt.Fprint,都有機會引起 data race!

目前我的解決方式是在 Write 裡先把 buffer 複製一份,再把複本傳進 logrotate 的 channel 裡;因為 fmt 在 Write 返回之後會重複使用原本的 buffer,而背景的 goroutine 此時可能還在讀取它。

p := make([]byte, len(b))
copy(p, b)

解題靈感來自

沒有留言: