package file import ( "io" "linkfog.com/public/lib/file/config" "linkfog.com/public/pkg/ratelimit" ) func CopyRateLimit(dst io.Writer, src io.Reader) (written int64, err error) { if !config.EnableRateLimitBucket() || config.RateLimitBucket() == nil { written, err = io.Copy(dst, src) if err != nil { return written, err } return written, err } return copyBufferRateLimit(dst, src, nil) } // copyBufferRateLimit is the actual implementation of Copy and CopyBuffer. // if buf is nil, one is allocated. func copyBufferRateLimit(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { // If the reader has a WriteTo method, use it to do the copy. // Avoids an allocation and a copy. if wt, ok := src.(io.WriterTo); ok { return wt.WriteTo(dst) } // Similarly, if the writer has a ReadFrom method, use it to do the copy. if rt, ok := dst.(io.ReaderFrom); ok { return rt.ReadFrom(src) } if buf == nil { size := config.RateLimitBucket().Capacity() if l, ok := src.(*io.LimitedReader); ok && size > l.N { if l.N < 1 { size = 1 } else { size = l.N } } buf = make([]byte, size) } rateWriter := ratelimit.Writer(dst, config.RateLimitBucket()) for { nr, er := src.Read(buf) if nr > 0 { nw, ew := rateWriter.Write(buf[0:nr]) if nw < 0 || nr < nw { nw = 0 if ew == nil { ew = errInvalidWrite } } written += int64(nw) if ew != nil { err = ew break } if nr != nw { err = io.ErrShortWrite break } } if er != nil { if er != io.EOF { err = er } break } } return written, err } func CopyBufferRateLimit(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { if !config.EnableRateLimitBucket() || config.RateLimitBucket() == nil { written, err = io.CopyBuffer(dst, src, buf) if err != nil { return written, err } return written, err } return copyBufferRateLimit(dst, src, buf) } // the copy is implemented using it. func CopyNRateLimit(dst io.Writer, src io.Reader, n int64) (written int64, err error) { written, err = CopyRateLimit(dst, io.LimitReader(src, n)) if written == n { return n, nil } if written < n && err == nil { // src stopped early; must have been EOF. err = io.EOF } return }