plugin_grpc_client.go 4.01 KB
Newer Older
“李磊”'s avatar
“李磊” committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
package pluginrpc

import (
	"context"
	"fmt"
	"io"
	"os"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"linkfog.com/public/lib/file"

	pb "linkfog.com/pluginx/proto"
)

var (
	// FileTransferBlockSize is the size, in bytes, of each chunk read from disk
	// and sent as one message when streaming a file to a plugin (100 KiB).
	FileTransferBlockSize = 100 * 1024 // unit:bytes

	// FileTransferTimeout is a file-transfer timeout expressed in minutes.
	// NOTE(review): not referenced by the code in this file; presumably
	// consumed elsewhere in the package — confirm.
	FileTransferTimeout = 30 // unit:minutes
)

// PluginGrpcClient is the plugin gRPC client. It holds one gRPC connection
// per configured plugin and routes each request to the connection whose name
// matches the request header's To field.
type PluginGrpcClient struct {
	// pluginMap maps plugin name to its dialed connection and client stub;
	// it is populated by NewPluginGrpcClient.
	pluginMap  map[string]*pluginClientConn
	// pluginConf is the map passed to NewPluginGrpcClient, stored as-is
	// (not copied), so it aliases the caller's map.
	pluginConf map[string]string // key:name value:unixPath
}

// pluginClientConn pairs a dialed gRPC connection with the Plugin client stub
// created on top of it.
type pluginClientConn struct {
	// conn is the underlying connection; PluginGrpcClient.Close closes it.
	conn   *grpc.ClientConn
	// client is the generated Plugin stub bound to conn.
	client pb.PluginClient
}

// NewPluginGrpcClient creates a client holding one gRPC connection per entry
// in pluginConf, which maps plugin name to the plugin's unix socket path.
// A path that does not already start with "unix:" is prefixed with it, so the
// dial target uses gRPC's unix-socket scheme. pluginConf is stored as-is and
// is not copied.
//
// Connections use insecure (plaintext) transport credentials. If any dial
// fails, every connection created so far is closed, and the dial error is
// returned wrapped with %w so callers can inspect it with errors.Is/As.
// grpc.Dial without WithBlock does not wait for the socket, so an
// unreachable plugin typically surfaces on its first RPC rather than here.
func NewPluginGrpcClient(pluginConf map[string]string) (*PluginGrpcClient, error) {
	h := &PluginGrpcClient{
		pluginMap:  make(map[string]*pluginClientConn, len(pluginConf)),
		pluginConf: pluginConf,
	}

	for name, unixPath := range pluginConf {
		target := unixPath
		if !strings.HasPrefix(target, "unix:") {
			target = "unix:" + target
		}
		conn, err := grpc.Dial(target,
			grpc.WithTransportCredentials(insecure.NewCredentials()),
			// grpc.WithKeepaliveParams(keepalive.ClientParameters{
			// 	Time:    10 * time.Second,
			// 	Timeout: 20 * time.Second,
			// }),
		)
		if err != nil {
			// Release the connections dialed on earlier iterations.
			h.Close()
			return nil, fmt.Errorf("failed to dial %s, err:%w", target, err)
		}

		h.pluginMap[name] = &pluginClientConn{conn: conn, client: pb.NewPluginClient(conn)}
	}

	return h, nil
}

// Close closes every plugin connection held by the client. It is
// best-effort: an error from closing one connection is discarded and does
// not prevent the remaining connections from being closed.
func (c *PluginGrpcClient) Close() {
	for name := range c.pluginMap {
		_ = c.pluginMap[name].conn.Close()
	}
}

// Call sends req over the unary Call RPC to the plugin named by
// req.Header.To and returns that plugin's reply. It returns an error when no
// connection is configured under that name.
func (c *PluginGrpcClient) Call(ctx context.Context, req *pb.Req) (*pb.Res, error) {
	// Find the client for the destination plugin. The generated getters are
	// nil-safe, so a nil req or header yields "not found" instead of a panic.
	dstPlugin := req.GetHeader().GetTo()
	cc, ok := c.pluginMap[dstPlugin]
	if !ok {
		return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin)
	}

	return cc.client.Call(ctx, req)
}

// SendFile streams the file at filePath to the plugin named by fs.Header.To
// over the client-streaming SendFile RPC and returns the plugin's final reply.
//
// fs is used as the message template and is mutated in place:
//   - TotalSize and TotalPart are filled in from the file's size when zero.
//   - Part is reset to 0 and incremented before each send, so chunks are
//     numbered from 1.
//   - Data is set to the chunk being sent; chunks are FileTransferBlockSize
//     bytes, except possibly the last.
//
// The RPC runs under a context derived from ctx that is cancelled when
// SendFile returns. This releases the stream even when the transfer is
// abandoned midway, for example after a file read error.
func (c *PluginGrpcClient) SendFile(ctx context.Context, fs *pb.FileStream, filePath string) (*pb.Res, error) {
	// Find the client for the destination plugin. The generated getters are
	// nil-safe, so a nil fs or header yields "not found" instead of a panic.
	dstPlugin := fs.GetHeader().GetTo()
	cc, ok := c.pluginMap[dstPlugin]
	if !ok {
		return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin)
	}

	// A non-positive block size would make calcTotalPart divide by zero.
	if FileTransferBlockSize <= 0 {
		return nil, fmt.Errorf("invalid FileTransferBlockSize: %d", FileTransferBlockSize)
	}

	info, err := os.Stat(filePath)
	if err != nil {
		return nil, fmt.Errorf("stat file err: %w", err)
	}

	if fs.TotalSize == 0 {
		fs.TotalSize = info.Size()
	}
	if fs.TotalPart == 0 {
		fs.TotalPart = calcTotalPart(info.Size(), int64(FileTransferBlockSize))
	}

	// If we return before CloseAndRecv completes the RPC, cancelling the
	// context is what tears the stream down.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	stream, err := cc.client.SendFile(ctx)
	if err != nil {
		return nil, fmt.Errorf("client.SendFile err: %w", err)
	}

	fs.Part = 0
	err = sendFileByPart(filePath, func(part []byte) error {
		// NOTE(review): part aliases sendFileByPart's reusable read buffer and
		// fs is mutated again for the next chunk. This assumes Send has
		// finished with the message by the time it returns — confirm no stats
		// handler or interceptor retains it.
		fs.Data = part
		fs.Part++
		if sendErr := stream.Send(fs); sendErr != nil {
			if sendErr == io.EOF {
				// io.EOF means the server ended the stream; its real status
				// can only be read from the receive side.
				if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
					sendErr = recvErr
				}
			}
			return fmt.Errorf("stream send err: %w", sendErr)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("sendFileByPart err: %w", err)
	}

	res, err := stream.CloseAndRecv()
	if err != nil {
		return nil, fmt.Errorf("stream CloseAndRecv err: %w", err)
	}

	return res, nil
}

// Chat opens the server-streaming Chat RPC to the plugin named by
// req.Header.To and returns the response stream for the caller to read. It
// returns an error when no connection is configured under that name.
//
// The stream stays open until the caller receives a non-nil error from it
// or ctx is cancelled. A caller that stops reading early must cancel ctx to
// release the stream.
func (c *PluginGrpcClient) Chat(ctx context.Context, req *pb.Req) (pb.Plugin_ChatClient, error) {
	// Find the client for the destination plugin. The generated getters are
	// nil-safe, so a nil req or header yields "not found" instead of a panic.
	dstPlugin := req.GetHeader().GetTo()
	cc, ok := c.pluginMap[dstPlugin]
	if !ok {
		return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin)
	}

	return cc.client.Chat(ctx, req)
}

func sendFileByPart(filePath string, send func([]byte) error) error {
	f, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open file err: %v", err)
	}
	defer f.Close()
	defer file.FadviseSwitch(f)

	var buf = make([]byte, FileTransferBlockSize)
	var finished = false
	for offset := int64(0); ; offset += int64(FileTransferBlockSize) {
		readLength, err := f.ReadAt(buf, offset)
		if err == io.EOF {
			finished = true
		} else if err != nil {
			return fmt.Errorf("err occured when reading file: %s, err: %v", filePath, err)
		}

		if readLength == 0 {
			break
		}

		if readLength != FileTransferBlockSize {
			// trailing garbage
			buf = buf[:readLength]
		}

		err = send(buf)
		if err != nil {
			return fmt.Errorf("send err: %v", err)
		}

		if finished {
			break
		}
	}

	return nil
}

// calcTotalPart returns how many parts of partSize bytes are needed to cover
// totalSize bytes: the integer quotient, plus one when there is a remainder.
// For non-negative totalSize and positive partSize this is the quotient
// rounded up. A zero partSize panics with an integer divide-by-zero.
func calcTotalPart(totalSize, partSize int64) int64 {
	parts := totalSize / partSize
	if remainder := totalSize % partSize; remainder != 0 {
		parts++
	}
	return parts
}