diff --git a/config/config.go b/config/config.go index d3f25cf40..487dddfee 100644 --- a/config/config.go +++ b/config/config.go @@ -62,11 +62,13 @@ type SectionAndroid struct { // SectionIos is sub section of config. type SectionIos struct { - Enabled bool `yaml:"enabled"` - KeyPath string `yaml:"key_path"` - Password string `yaml:"password"` - Production bool `yaml:"production"` - MaxRetry int `yaml:"max_retry"` + Enabled bool `yaml:"enabled"` + KeyPath string `yaml:"key_path"` + Password string `yaml:"password"` + Production bool `yaml:"production"` + MaxRetry int `yaml:"max_retry"` + KeyMap map[string]string `yaml:"key_map"` + KeyPass map[string]string `yaml:"key_password"` } // SectionLog is sub section of config. diff --git a/config/config.yml b/config/config.yml index cdc5a4dc1..58c6aab18 100644 --- a/config/config.yml +++ b/config/config.yml @@ -43,6 +43,14 @@ ios: password: "" # certificate password, default as empty string. production: false max_retry: 0 # resend fail notification, default value zero is disabled + key_map: + cert1: "cert1.pem" + cert2: "cert2.pem" + cert3: "cert3.pem" + key_password: + cert1: "" + cert2: "" + cert3: "" log: format: "string" # string or json diff --git a/config/config_test.go b/config/config_test.go index a4bc43a26..e2dfcbce4 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -162,6 +162,7 @@ func (suite *ConfigTestSuite) TestValidateConf() { assert.Equal(suite.T(), "", suite.ConfGorush.Ios.Password) assert.Equal(suite.T(), false, suite.ConfGorush.Ios.Production) assert.Equal(suite.T(), 0, suite.ConfGorush.Ios.MaxRetry) + assert.Equal(suite.T(), "cert1.pem", suite.ConfGorush.Ios.KeyMap["cert1"]) // log assert.Equal(suite.T(), "string", suite.ConfGorush.Log.Format) diff --git a/gorush/global.go b/gorush/global.go index aadb525f3..525cf014f 100644 --- a/gorush/global.go +++ b/gorush/global.go @@ -28,4 +28,6 @@ var ( LogError *logrus.Logger // StatStorage implements the storage interface StatStorage storage.Storage + // Extend + ApnsClients = make(map[string]*apns.Client, len(PushConf.Ios.KeyMap)) ) diff --git a/gorush/notification.go b/gorush/notification.go index a22f5cf5f..70dbc7c70 100644 --- a/gorush/notification.go +++ b/gorush/notification.go @@ -72,6 +72,7 @@ type PushNotification struct { Notification fcm.Notification `json:"notification,omitempty"` // iOS + ApnsClient string `json:"apns_client,omitempty"` Expiration int64 `json:"expiration,omitempty"` ApnsID string `json:"apns_id,omitempty"` Topic string `json:"topic,omitempty"` diff --git a/gorush/notification_apns.go b/gorush/notification_apns.go index 8b3720e94..e9997d46d 100644 --- a/gorush/notification_apns.go +++ b/gorush/notification_apns.go @@ -11,34 +11,50 @@ import ( ) // InitAPNSClient use for initialize APNs Client. -func InitAPNSClient() error { - if PushConf.Ios.Enabled { - var err error - ext := filepath.Ext(PushConf.Ios.KeyPath) - - switch ext { - case ".p12": - CertificatePemIos, err = certificate.FromP12File(PushConf.Ios.KeyPath, PushConf.Ios.Password) - case ".pem": - CertificatePemIos, err = certificate.FromPemFile(PushConf.Ios.KeyPath, PushConf.Ios.Password) - default: - err = errors.New("wrong certificate key extension") +func InitAPNSClient(key string) (*apns.Client, error) { + path, password := "", "" + if len(key) > 0 { + if _, ok := PushConf.Ios.KeyMap[key]; !ok { + LogError.Errorf("Key %s key_map not exist", key) + return nil, errors.New("APNS key_map not exists") } + if _, ok := PushConf.Ios.KeyMap[key]; !ok { + LogError.Errorf("Key %s key_password not exist", key) + return nil, errors.New("APNS key_password not exists") + } + path = PushConf.Ios.KeyMap[key] + password = PushConf.Ios.KeyMap[key] + } else { + path = PushConf.Ios.KeyPath + password = PushConf.Ios.Password + } - if err != nil { - LogError.Error("Cert Error:", err.Error()) + if c, ok := ApnsClients[path]; ok { + return c, nil + } + var err error + ext := filepath.Ext(path) + switch ext { + case ".p12": + CertificatePemIos, err = certificate.FromP12File(path, password) + case ".pem": + CertificatePemIos, err = certificate.FromPemFile(path, password) + default: + err = errors.New("wrong certificate key extension") + } - return err - } + if err != nil { + LogError.Error("Cert Error:", err.Error()) + return nil, err + } - if PushConf.Ios.Production { - ApnsClient = apns.NewClient(CertificatePemIos).Production() - } else { - ApnsClient = apns.NewClient(CertificatePemIos).Development() - } + if PushConf.Ios.Production { + ApnsClients[path] = apns.NewClient(CertificatePemIos).Production() + } else { + ApnsClients[path] = apns.NewClient(CertificatePemIos).Development() } - return nil + return ApnsClients[path], nil } func iosAlertDictionary(payload *payload.Payload, req PushNotification) *payload.Payload { @@ -162,6 +178,8 @@ func PushToIOS(req PushNotification) bool { } var ( + err error + client *apns.Client retryCount = 0 maxRetry = PushConf.Ios.MaxRetry ) @@ -181,8 +199,19 @@ Retry: for _, token := range req.Tokens { notification.DeviceToken = token + if req.ApnsClient != "" { + client, err = InitAPNSClient(req.ApnsClient) + } else { + client, err = InitAPNSClient("") + } + + if err != nil { + // APNS server error + LogError.Error("APNS server error: " + err.Error()) + return false + } // send ios notification - res, err := ApnsClient.Push(notification) + res, err := client.Push(notification) if err != nil { // apns server error diff --git a/gorush/notification_apns_test.go b/gorush/notification_apns_test.go index 90cfe5492..bade01360 100644 --- a/gorush/notification_apns_test.go +++ b/gorush/notification_apns_test.go @@ -349,7 +349,7 @@ func TestDisabledIosNotifications(t *testing.T) { PushConf.Ios.Enabled = false PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + _, err := InitAPNSClient("") assert.Nil(t, err) PushConf.Android.Enabled = true @@ -384,7 +384,7 @@ func TestWrongIosCertificateExt(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "test" - err := InitAPNSClient() + _, err := InitAPNSClient("") assert.Error(t, err) assert.Equal(t, "wrong certificate key extension", err.Error()) @@ -395,9 +395,10 @@ func TestAPNSClientDevHost(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.p12" - err := InitAPNSClient() + ApnsClients = make(map[string]*apns2.Client, len(PushConf.Ios.KeyMap)) + client, err := InitAPNSClient("") assert.Nil(t, err) - assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host) + assert.Equal(t, apns2.HostDevelopment, client.Host) } func TestAPNSClientProdHost(t *testing.T) { @@ -406,9 +407,10 @@ func TestAPNSClientProdHost(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.Production = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + ApnsClients = make(map[string]*apns2.Client, len(PushConf.Ios.KeyMap)) + client, err := InitAPNSClient("") assert.Nil(t, err) - assert.Equal(t, apns2.HostProduction, ApnsClient.Host) + assert.Equal(t, apns2.HostProduction, client.Host) } func TestPushToIOS(t *testing.T) { @@ -416,7 +418,7 @@ func TestPushToIOS(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() + _, err := InitAPNSClient("") assert.Nil(t, err) err = InitAppStatus() assert.Nil(t, err) @@ -431,3 +433,48 @@ func TestPushToIOS(t *testing.T) { isError := PushToIOS(req) assert.True(t, isError) } + +func TestProvideApnsClient(t *testing.T) { + PushConf = config.BuildDefaultPushConf() + + PushConf.Ios.Enabled = true + // PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" + ApnsClients = make(map[string]*apns2.Client, len(PushConf.Ios.KeyMap)) + PushConf.Ios.KeyMap = map[string]string{"cert1": "../certificate/certificate-valid.pem"} + PushConf.Ios.KeyPass = map[string]string{"cert1": ""} + + err := InitAppStatus() + assert.Nil(t, err) + + req := PushNotification{ + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: 1, + Message: "Welcome", + ApnsClient: "cert1", + } + + // send success + isError := PushToIOS(req) + assert.True(t, isError) +} + +func TestProvideWrongApnsClient(t *testing.T) { + PushConf = config.BuildDefaultPushConf() + + PushConf.Ios.Enabled = true + PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" + ApnsClients = make(map[string]*apns2.Client, len(PushConf.Ios.KeyMap)) + err := InitAppStatus() + assert.Nil(t, err) + + req := PushNotification{ + Tokens: []string{"11aa01229f15f0f0c52029d8cf8cd0aeaf2365fe4cebc4af26cd6d76b7919ef7"}, + Platform: 1, + Message: "Welcome", + ApnsClient: "cert_not_exist", + } + + // send fail + isError := PushToIOS(req) + assert.False(t, isError) +} diff --git a/gorush/notification_test.go b/gorush/notification_test.go index f4d454dd7..e25b9f1b3 100644 --- a/gorush/notification_test.go +++ b/gorush/notification_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/appleboy/gorush/config" + "github.com/sideshow/apns2" "github.com/stretchr/testify/assert" ) @@ -29,8 +30,6 @@ func TestSenMultipleNotifications(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) PushConf.Android.Enabled = true PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") @@ -64,8 +63,6 @@ func TestDisabledAndroidNotifications(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) PushConf.Android.Enabled = false PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") @@ -99,8 +96,7 @@ func TestSyncModeForNotifications(t *testing.T) { PushConf.Ios.Enabled = true PushConf.Ios.KeyPath = "../certificate/certificate-valid.pem" - err := InitAPNSClient() - assert.Nil(t, err) + ApnsClients = make(map[string]*apns2.Client, len(PushConf.Ios.KeyMap)) PushConf.Android.Enabled = true PushConf.Android.APIKey = os.Getenv("ANDROID_API_KEY") diff --git a/gorush/worker.go b/gorush/worker.go index 734d1f74e..300254395 100644 --- a/gorush/worker.go +++ b/gorush/worker.go @@ -1,6 +1,7 @@ package gorush import ( + "fmt" "sync" ) @@ -41,6 +42,11 @@ func queueNotification(req RequestPush) (int, []LogPushEntry) { if !PushConf.Ios.Enabled { continue } + if len(notification.ApnsClient) > 0 { + if _, ok := PushConf.Ios.KeyMap[notification.ApnsClient]; !ok { + continue + } + } case PlatFormAndroid: if !PushConf.Android.Enabled { continue @@ -56,6 +62,7 @@ func queueNotification(req RequestPush) (int, []LogPushEntry) { notification.log = &log notification.AddWaitCount() } + fmt.Println("-----") QueueNotification <- notification count += len(notification.Tokens) } diff --git a/main.go b/main.go index d9fd52ffc..6b4076655 100644 --- a/main.go +++ b/main.go @@ -189,7 +189,7 @@ func main() { return } - if err := gorush.InitAPNSClient(); err != nil { + if _, err := gorush.InitAPNSClient(""); err != nil { return } gorush.PushToIOS(req) @@ -220,7 +220,8 @@ func main() { var g errgroup.Group g.Go(func() error { - return gorush.InitAPNSClient() + _, err := gorush.InitAPNSClient("") + return err }) g.Go(func() error {