package pluginrpc import ( "context" "fmt" "io" "os" "strings" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "linkfog.com/public/lib/file" pb "linkfog.com/pluginx/proto" ) var ( FileTransferBlockSize = 102400 // unit:bytes FileTransferTimeout = 30 // unit:minutes ) // 插件grpc客户端 type PluginGrpcClient struct { pluginMap map[string]*pluginClientConn pluginConf map[string]string // key:name value:unixPath } type pluginClientConn struct { conn *grpc.ClientConn client pb.PluginClient } func NewPluginGrpcClient(pluginConf map[string]string) (*PluginGrpcClient, error) { h := &PluginGrpcClient{ pluginMap: make(map[string]*pluginClientConn), pluginConf: pluginConf, } for name, unixPath := range pluginConf { if !strings.HasPrefix(unixPath, "unix:") { unixPath = "unix:" + unixPath } conn, err := grpc.Dial(unixPath, grpc.WithTransportCredentials(insecure.NewCredentials()), // grpc.WithKeepaliveParams(keepalive.ClientParameters{ // Time: 10 * time.Second, // Timeout: 20 * time.Second, // }), ) if err != nil { h.Close() return nil, fmt.Errorf("failed to dial %s, err:%v", unixPath, err) } client := pb.NewPluginClient(conn) h.pluginMap[name] = &pluginClientConn{conn: conn, client: client} } return h, nil } func (h *PluginGrpcClient) Close() { for _, cc := range h.pluginMap { cc.conn.Close() } } func (c *PluginGrpcClient) Call(ctx context.Context, req *pb.Req) (*pb.Res, error) { // 查找目的插件对应的client dstPlugin := req.Header.To cc, ok := c.pluginMap[dstPlugin] if !ok { return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin) } return cc.client.Call(ctx, req) } func (c *PluginGrpcClient) SendFile(ctx context.Context, fs *pb.FileStream, filePath string) (*pb.Res, error) { // 查找目的插件对应的client dstPlugin := fs.Header.To cc, ok := c.pluginMap[dstPlugin] if !ok { return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin) } info, err := os.Stat(filePath) if err != nil { return nil, fmt.Errorf("stat file err: %v", err) } if fs.TotalSize == 0 { fs.TotalSize = info.Size() } if fs.TotalPart == 0 { fs.TotalPart = calcTotalPart(info.Size(), int64(FileTransferBlockSize)) } stream, err := cc.client.SendFile(ctx) if err != nil { return nil, fmt.Errorf("client.SendFile err: %v", err) } fs.Part = 0 err = sendFileByPart(filePath, func(part []byte) error { fs.Data = part fs.Part++ err = stream.Send(fs) if err != nil { return fmt.Errorf("stream send err: %v", err) } return nil }) if err != nil { return nil, fmt.Errorf("sendFileByPart err: %v", err) } res, err := stream.CloseAndRecv() if err != nil { return nil, fmt.Errorf("stream CloseAndRecv err: %v", err) } return res, nil } func (c *PluginGrpcClient) Chat(ctx context.Context, req *pb.Req) (pb.Plugin_ChatClient, error) { // 查找目的插件对应的client dstPlugin := req.Header.To cc, ok := c.pluginMap[dstPlugin] if !ok { return nil, fmt.Errorf("dst plugin conn %s not found", dstPlugin) } return cc.client.Chat(ctx, req) } func sendFileByPart(filePath string, send func([]byte) error) error { f, err := os.Open(filePath) if err != nil { return fmt.Errorf("open file err: %v", err) } defer f.Close() defer file.FadviseSwitch(f) var buf = make([]byte, FileTransferBlockSize) var finished = false for offset := int64(0); ; offset += int64(FileTransferBlockSize) { readLength, err := f.ReadAt(buf, offset) if err == io.EOF { finished = true } else if err != nil { return fmt.Errorf("err occured when reading file: %s, err: %v", filePath, err) } if readLength == 0 { break } if readLength != FileTransferBlockSize { // trailing garbage buf = buf[:readLength] } err = send(buf) if err != nil { return fmt.Errorf("send err: %v", err) } if finished { break } } return nil } func calcTotalPart(totalSize, partSize int64) int64 { if totalSize%partSize == 0 { return totalSize / partSize } return totalSize/partSize + 1 }