copy.go 2.31 KB
Newer Older
“李磊”'s avatar
“李磊” committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package file

import (
	"io"

	"linkfog.com/public/lib/file/config"
	"linkfog.com/public/pkg/ratelimit"
)

// CopyRateLimit copies from src to dst and returns the number of bytes
// written together with the first error encountered, if any.
//
// When rate limiting is switched off in config, or no bucket is configured,
// the copy is delegated to io.Copy. Otherwise it is carried out by
// copyBufferRateLimit with a nil buffer.
func CopyRateLimit(dst io.Writer, src io.Reader) (written int64, err error) {
	// RateLimitBucket is only consulted when the feature flag is on.
	limited := config.EnableRateLimitBucket() && config.RateLimitBucket() != nil
	if limited {
		return copyBufferRateLimit(dst, src, nil)
	}
	return io.Copy(dst, src)
}

// copyBufferRateLimit is the shared implementation behind CopyRateLimit and
// CopyBufferRateLimit. It reads from src into buf and writes every chunk to
// dst through a ratelimit.Writer built from the configured bucket.
//
// If buf is nil or has zero length, a buffer is allocated. Its size is the
// bucket capacity, reduced to the remaining byte count when src is an
// *io.LimitedReader, and never smaller than one byte. A zero-length buffer
// must not reach the read loop: Read on an empty slice may return (0, nil)
// indefinitely, which would spin forever.
//
// It returns the number of bytes written to dst and the first error
// encountered. io.EOF from src ends the copy and is not reported as an error.
func copyBufferRateLimit(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
	bucket := config.RateLimitBucket()

	// Unlike io.Copy, the io.WriterTo / io.ReaderFrom shortcuts are
	// deliberately not taken here: they would move the data directly between
	// src and dst and skip the rate-limited writer altogether.
	if len(buf) == 0 {
		size := bucket.Capacity()
		if l, ok := src.(*io.LimitedReader); ok && size > l.N {
			// No point allocating more than the reader can still deliver.
			size = l.N
		}
		if size < 1 {
			size = 1
		}
		buf = make([]byte, size)
	}

	rateWriter := ratelimit.Writer(dst, bucket)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := rateWriter.Write(buf[0:nr])
			if nw < 0 || nr < nw {
				// The writer reported an impossible byte count; do not
				// trust it and surface an error if it did not give one.
				nw = 0
				if ew == nil {
					ew = errInvalidWrite
				}
			}
			written += int64(nw)
			if ew != nil {
				err = ew
				break
			}
			if nr != nw {
				err = io.ErrShortWrite
				break
			}
		}
		if er != nil {
			// EOF is the normal end of the copy, not a failure.
			if er != io.EOF {
				err = er
			}
			break
		}
	}
	return written, err
}

// CopyBufferRateLimit copies from src to dst like CopyRateLimit, but passes
// the caller-supplied buf along as the staging buffer. It returns the number
// of bytes written and the first error encountered, if any.
//
// When rate limiting is switched off in config, or no bucket is configured,
// the copy is delegated to io.CopyBuffer. Otherwise it is carried out by
// copyBufferRateLimit, which receives buf unchanged.
func CopyBufferRateLimit(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) {
	// RateLimitBucket is only consulted when the feature flag is on.
	limited := config.EnableRateLimitBucket() && config.RateLimitBucket() != nil
	if limited {
		return copyBufferRateLimit(dst, src, buf)
	}
	return io.CopyBuffer(dst, src, buf)
}

// CopyNRateLimit copies up to n bytes from src to dst by running
// CopyRateLimit over an io.LimitReader capped at n. It returns the number of
// bytes written and an error.
//
// When exactly n bytes were written the error is nil. When fewer than n
// bytes were written and the copy itself reported no error, io.EOF is
// returned, since src must have ended early.
func CopyNRateLimit(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
	capped := io.LimitReader(src, n)
	written, err = CopyRateLimit(dst, capped)
	switch {
	case written == n:
		return n, nil
	case written < n && err == nil:
		// src stopped early; must have been EOF.
		return written, io.EOF
	}
	return written, err
}