Merge pull request #7 from Selly-Modules/support-uri-with-options

support uri with options
This commit is contained in:
Sinh Luu 2022-06-16 14:22:19 +07:00 committed by GitHub
commit 438cfd8b9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 75 additions and 3 deletions

View File

@ -3,6 +3,7 @@ package mongodb
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"github.com/logrusorgru/aurora" "github.com/logrusorgru/aurora"
"go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/event"
@ -48,7 +49,7 @@ func Connect(cfg Config) (*mongo.Database, error) {
connectOptions := options.ClientOptions{} connectOptions := options.ClientOptions{}
opts := cfg.Standalone opts := cfg.Standalone
// Set auth if existed // Set auth if existed
if opts.Username != "" && opts.Password != "" { if opts != nil && opts.Username != "" && opts.Password != "" {
connectOptions.Auth = &options.Credential{ connectOptions.Auth = &options.Credential{
AuthMechanism: opts.AuthMechanism, AuthMechanism: opts.AuthMechanism,
AuthSource: opts.AuthSource, AuthSource: opts.AuthSource,
@ -87,8 +88,7 @@ func connectWithTLS(cfg Config) (*mongo.Database, error) {
return nil, err return nil, err
} }
pwd := base64DecodeToString(opts.CertKeyFilePassword) pwd := base64DecodeToString(opts.CertKeyFilePassword)
s := "%s/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509" uri := getURIWithTLS(cfg, caFile.Name(), certFile.Name(), pwd)
uri := fmt.Sprintf(s, cfg.Host, caFile.Name(), certFile.Name(), pwd)
readPref := getReadPref(opts.ReadPreferenceMode) readPref := getReadPref(opts.ReadPreferenceMode)
clientOpts := options.Client().SetReadPreference(readPref).SetReplicaSet(opts.ReplSet).ApplyURI(uri) clientOpts := options.Client().SetReadPreference(readPref).SetReplicaSet(opts.ReplSet).ApplyURI(uri)
if cfg.Monitor != nil { if cfg.Monitor != nil {
@ -107,6 +107,22 @@ func connectWithTLS(cfg Config) (*mongo.Database, error) {
return db, err return db, err
} }
func getURIWithTLS(cfg Config, caFilePath, certFilePath, pwd string) string {
host := cfg.Host
if strings.Contains(host, "?") {
host += "&"
} else {
if !strings.HasSuffix(host, "/") {
host += "/?"
} else {
host += "?"
}
}
s := "%stls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509"
uri := fmt.Sprintf(s, host, caFilePath, certFilePath, pwd)
return uri
}
// GetInstance ... // GetInstance ...
func GetInstance() *mongo.Database { func GetInstance() *mongo.Database {
return db return db

View File

@ -1,6 +1,7 @@
package mongodb package mongodb
import ( import (
"fmt"
"testing" "testing"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
@ -45,3 +46,58 @@ func Test_connectWithTLS(t *testing.T) {
}) })
} }
} }
func Test_getURIWithTLS(t *testing.T) {
type args struct {
cfg Config
caFilePath string
certFilePath string
pwd string
}
ca := "ca.pem"
cert := "cert.pem"
pwd := "1"
tests := []struct {
name string
args args
want string
}{
{
name: "uri no options",
args: args{
cfg: Config{Host: "mongodb://localhost:27017"},
caFilePath: ca,
certFilePath: cert,
pwd: pwd,
},
want: fmt.Sprintf("mongodb://localhost:27017/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd),
},
{
name: "uri no options, end with /",
args: args{
cfg: Config{Host: "mongodb://localhost:27017/"},
caFilePath: ca,
certFilePath: cert,
pwd: pwd,
},
want: fmt.Sprintf("mongodb://localhost:27017/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd),
},
{
name: "uri has options",
args: args{
cfg: Config{Host: "mongodb://localhost:27017/?a=1"},
caFilePath: ca,
certFilePath: cert,
pwd: pwd,
},
want: fmt.Sprintf("mongodb://localhost:27017/?a=1&tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getURIWithTLS(tt.args.cfg, tt.args.caFilePath, tt.args.certFilePath, tt.args.pwd); got != tt.want {
t.Errorf("getURIWithTLS() = %v, want %v", got, tt.want)
}
})
}
}