Commit b2bed5e2 authored by Lei Li's avatar Lei Li
Browse files

feat: 增加自升级处理逻辑

parent 049a14ce
...@@ -5,7 +5,7 @@ import "time" ...@@ -5,7 +5,7 @@ import "time"
var ( var (
WorkDir = "/linkfog/agent/" WorkDir = "/linkfog/agent/"
BinDir = "bin" BackupDir = "backup"
PluginDir = "plugin" PluginDir = "plugin"
LogDir = "log" LogDir = "log"
LogFileName = "agent.log" LogFileName = "agent.log"
......
...@@ -27,6 +27,7 @@ var monitorDaemon *monitor_daemon.MonitorDaemon ...@@ -27,6 +27,7 @@ var monitorDaemon *monitor_daemon.MonitorDaemon
type Agent struct { type Agent struct {
IsStop bool IsStop bool
cmdline0 string
} }
func (a *Agent) Init(env svc.Environment) error { func (a *Agent) Init(env svc.Environment) error {
...@@ -46,6 +47,8 @@ func (a *Agent) Init(env svc.Environment) error { ...@@ -46,6 +47,8 @@ func (a *Agent) Init(env svc.Environment) error {
if option.Opt.EnableMonitorDaemon { if option.Opt.EnableMonitorDaemon {
config.LogFileName = config.DaemonLogFileName config.LogFileName = config.DaemonLogFileName
} else {
a.cmdline0 = os.Args[0]
} }
// 日志设置 // 日志设置
...@@ -79,7 +82,7 @@ func (a *Agent) Start() error { ...@@ -79,7 +82,7 @@ func (a *Agent) Start() error {
monitor_daemon.WithMaxInterval(option.Opt.MonitorDefaultMaxInterval), // 设置重新拉起进程的最大时间间隔,防止频繁重启带来的系统压力 monitor_daemon.WithMaxInterval(option.Opt.MonitorDefaultMaxInterval), // 设置重新拉起进程的最大时间间隔,防止频繁重启带来的系统压力
monitor_daemon.WithIncrInterval(option.Opt.MonitorDefaultIncrInterval), // 设置重新拉起进程的时间间隔,防止频繁重启带来的系统压力 monitor_daemon.WithIncrInterval(option.Opt.MonitorDefaultIncrInterval), // 设置重新拉起进程的时间间隔,防止频繁重启带来的系统压力
monitor_daemon.WithMaxRetryTimes(option.Opt.MonitorDefaultMaxRetryTimes), monitor_daemon.WithMaxRetryTimes(option.Opt.MonitorDefaultMaxRetryTimes),
monitor_daemon.WithAgentBackupPath(filepath.Join(config.BinDir, option.Opt.MonitorDefaultBackupFileName)), monitor_daemon.WithAgentBackupPath(filepath.Join(config.BackupDir, option.Opt.MonitorDefaultBackupFileName)),
} }
if option.Opt.EnableProcAbnormalCb { if option.Opt.EnableProcAbnormalCb {
opts = append(opts, monitor_daemon.WithIsProcAbnormalCallback(IsProcAbnormal)) opts = append(opts, monitor_daemon.WithIsProcAbnormalCallback(IsProcAbnormal))
...@@ -157,6 +160,10 @@ func initSvcCfg() error { ...@@ -157,6 +160,10 @@ func initSvcCfg() error {
} }
global.SubscribePrefixInfo = global.DeviceSerialNumber + "/publish/" global.SubscribePrefixInfo = global.DeviceSerialNumber + "/publish/"
// NOTE: 自测使用
if config.Edition == "dev" {
config.WorkDir = option.Opt.WorkDir
}
// 时间修正 // 时间修正
if !iTime.SynchronizeSystemTime() { if !iTime.SynchronizeSystemTime() {
return fmt.Errorf("synchronization system time failed") return fmt.Errorf("synchronization system time failed")
......
package core package core
import ( import (
"agent/module" "agent/pkg/file"
"os"
"path/filepath"
"strings"
"syscall"
"time" "time"
"agent/cmd/agent/config"
"agent/cmd/agent/global" "agent/cmd/agent/global"
"agent/cmd/agent/option"
"agent/module"
"linkfog.com/public/lib/l" "linkfog.com/public/lib/l"
) )
...@@ -28,6 +35,8 @@ func (a *Agent) chatMsgProcess(msg *global.Message) { ...@@ -28,6 +35,8 @@ func (a *Agent) chatMsgProcess(msg *global.Message) {
switch msg.Key { switch msg.Key {
case global.ConsumerTopicAgentSelfUpgrade: case global.ConsumerTopicAgentSelfUpgrade:
go a.SelfUpgradeProcess()
return
case global.ConsumerTopicPluginUpgrade: case global.ConsumerTopicPluginUpgrade:
fallthrough fallthrough
case global.ConsumerTopicStartupPlugin: case global.ConsumerTopicStartupPlugin:
...@@ -40,3 +49,31 @@ func (a *Agent) chatMsgProcess(msg *global.Message) { ...@@ -40,3 +49,31 @@ func (a *Agent) chatMsgProcess(msg *global.Message) {
} }
} }
} }
func (a *Agent) SelfUpgradeProcess() {
execFile := a.cmdline0[strings.LastIndex(a.cmdline0, "/")+1:]
// 备份
err := os.Rename(a.cmdline0, config.BackupDir+"/"+execFile)
if err != nil {
l.Error(err)
return
}
// 下载
err = global.DownloadFile(option.Opt.DownloadURL+"/"+execFile, config.WorkDir, execFile)
if err != nil {
l.Errorf("download file err:%v", err)
return
}
// 增加执行权限
err = file.SetFileExecPerm(filepath.Join(config.WorkDir, execFile))
if err != nil {
l.Error(err)
}
// TODO md5校验
// kill自身
err = syscall.Kill(os.Getpid(), syscall.SIGTERM)
if err != nil {
l.Error(err)
}
}
...@@ -87,14 +87,14 @@ func SetWorkDir() error { ...@@ -87,14 +87,14 @@ func SetWorkDir() error {
return err return err
} }
config.BinDir = path.Join(config.WorkDir, config.BinDir) config.BackupDir = path.Join(config.WorkDir, config.BackupDir)
err = os.MkdirAll(config.BinDir, 0644) err = os.MkdirAll(config.BackupDir, 0644)
if err != nil { if err != nil {
return err return err
} }
config.PluginDir = path.Join(config.WorkDir, config.PluginDir) config.PluginDir = path.Join(config.WorkDir, config.PluginDir)
err = os.MkdirAll(config.BinDir, 0644) err = os.MkdirAll(config.PluginDir, 0644)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -6,6 +6,8 @@ import ( ...@@ -6,6 +6,8 @@ import (
"crypto/sha512" "crypto/sha512"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
...@@ -85,3 +87,23 @@ func produceDeviceID() (string, error) { ...@@ -85,3 +87,23 @@ func produceDeviceID() (string, error) {
return string(snByte), nil return string(snByte), nil
} }
func DownloadFile(url, filepath, name string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("download file failed, err:%v", err)
}
defer resp.Body.Close()
file, err := os.Create(path.Join(filepath, name))
if err != nil {
return fmt.Errorf("create file failed, err:%v", err)
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
if err != nil {
return fmt.Errorf("copy file failed, err:%v", err)
}
return nil
}
...@@ -43,6 +43,8 @@ type options struct { ...@@ -43,6 +43,8 @@ type options struct {
// 调试信息 // 调试信息
MQTTBrokerURL string MQTTBrokerURL string
MQTTBrokerPort int MQTTBrokerPort int
DownloadURL string
WorkDir string
} }
const ( const (
...@@ -73,8 +75,11 @@ const ( ...@@ -73,8 +75,11 @@ const (
func flagSet() { func flagSet() {
if config.Edition == "dev" { if config.Edition == "dev" {
// NOTE: 自测使用
flag.StringVar(&Opt.MQTTBrokerURL, "mqtt-url", "tcp://broker.hivemq.com", "MQTT broker url") flag.StringVar(&Opt.MQTTBrokerURL, "mqtt-url", "tcp://broker.hivemq.com", "MQTT broker url")
flag.IntVar(&Opt.MQTTBrokerPort, "mqtt-port", 1883, "MQTT broker port") flag.IntVar(&Opt.MQTTBrokerPort, "mqtt-port", 1883, "MQTT broker port")
flag.StringVar(&Opt.DownloadURL, "download-url", "http://fae-cdn.linkfog.cn/tmp", "download url")
flag.StringVar(&Opt.WorkDir, "work-dir", "", "work dir")
} }
flag.BoolVar(&Opt.Usage, ArgUsage, false, "custom format usage") flag.BoolVar(&Opt.Usage, ArgUsage, false, "custom format usage")
flag.BoolVar(&Opt.PrintVersion, ArgVersion, false, "print version and exit") flag.BoolVar(&Opt.PrintVersion, ArgVersion, false, "print version and exit")
......
...@@ -40,6 +40,7 @@ func New(opts ...BackendOpt) *Backend { ...@@ -40,6 +40,7 @@ func New(opts ...BackendOpt) *Backend {
clientOpts.AddBroker(fmt.Sprintf("ssl://%s", b.BrokerIP)) clientOpts.AddBroker(fmt.Sprintf("ssl://%s", b.BrokerIP))
} }
// NOTE: 自测使用
if config.Edition == "dev" { if config.Edition == "dev" {
clientOpts.AddBroker(fmt.Sprintf("%s:%d", option.Opt.MQTTBrokerURL, option.Opt.MQTTBrokerPort)) clientOpts.AddBroker(fmt.Sprintf("%s:%d", option.Opt.MQTTBrokerURL, option.Opt.MQTTBrokerPort))
} }
...@@ -101,7 +102,7 @@ func (b *Backend) Start() error { ...@@ -101,7 +102,7 @@ func (b *Backend) Start() error {
b.client.Publish(topic, 1, false, []byte(msg.Payload)) b.client.Publish(topic, 1, false, []byte(msg.Payload))
} }
case <-b.signal: case <-b.signal:
l.Info("MQTT publish exit") l.Info("MQTT exit")
return return
} }
} }
......
package file
import (
"os"
)
func SetFileExecPerm(file string) error {
fileInfo, err := os.Stat(file)
if err != nil {
return err
}
newMode := fileInfo.Mode() | 0111
err = os.Chmod(file, newMode)
if err != nil {
return err
}
return nil
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment