From 2b279769e6d4a3d664b4a4e7e097bcc764808101 Mon Sep 17 00:00:00 2001 From: Sinh Date: Thu, 16 Jun 2022 14:21:34 +0700 Subject: [PATCH] support uri with options --- mongodb.go | 22 ++++++++++++++++--- mongodb_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/mongodb.go b/mongodb.go index 3786445..bdc640b 100644 --- a/mongodb.go +++ b/mongodb.go @@ -3,6 +3,7 @@ package mongodb import ( "context" "fmt" + "strings" "github.com/logrusorgru/aurora" "go.mongodb.org/mongo-driver/event" @@ -48,7 +49,7 @@ func Connect(cfg Config) (*mongo.Database, error) { connectOptions := options.ClientOptions{} opts := cfg.Standalone // Set auth if existed - if opts.Username != "" && opts.Password != "" { + if opts != nil && opts.Username != "" && opts.Password != "" { connectOptions.Auth = &options.Credential{ AuthMechanism: opts.AuthMechanism, AuthSource: opts.AuthSource, @@ -87,8 +88,7 @@ func connectWithTLS(cfg Config) (*mongo.Database, error) { return nil, err } pwd := base64DecodeToString(opts.CertKeyFilePassword) - s := "%s/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509" - uri := fmt.Sprintf(s, cfg.Host, caFile.Name(), certFile.Name(), pwd) + uri := getURIWithTLS(cfg, caFile.Name(), certFile.Name(), pwd) readPref := getReadPref(opts.ReadPreferenceMode) clientOpts := options.Client().SetReadPreference(readPref).SetReplicaSet(opts.ReplSet).ApplyURI(uri) if cfg.Monitor != nil { @@ -107,6 +107,22 @@ func connectWithTLS(cfg Config) (*mongo.Database, error) { return db, err } +func getURIWithTLS(cfg Config, caFilePath, certFilePath, pwd string) string { + host := cfg.Host + if strings.Contains(host, "?") { + host += "&" + } else { + if !strings.HasSuffix(host, "/") { + host += "/?" + } else { + host += "?" + } + } + s := "%stls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509" + uri := fmt.Sprintf(s, host, caFilePath, certFilePath, pwd) + return uri +} + // GetInstance ... func GetInstance() *mongo.Database { return db diff --git a/mongodb_test.go b/mongodb_test.go index ae6273b..950303f 100644 --- a/mongodb_test.go +++ b/mongodb_test.go @@ -1,6 +1,7 @@ package mongodb import ( + "fmt" "testing" "go.mongodb.org/mongo-driver/mongo" @@ -45,3 +46,58 @@ func Test_connectWithTLS(t *testing.T) { }) } } + +func Test_getURIWithTLS(t *testing.T) { + type args struct { + cfg Config + caFilePath string + certFilePath string + pwd string + } + ca := "ca.pem" + cert := "cert.pem" + pwd := "1" + tests := []struct { + name string + args args + want string + }{ + { + name: "uri no options", + args: args{ + cfg: Config{Host: "mongodb://localhost:27017"}, + caFilePath: ca, + certFilePath: cert, + pwd: pwd, + }, + want: fmt.Sprintf("mongodb://localhost:27017/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd), + }, + { + name: "uri no options, end with /", + args: args{ + cfg: Config{Host: "mongodb://localhost:27017/"}, + caFilePath: ca, + certFilePath: cert, + pwd: pwd, + }, + want: fmt.Sprintf("mongodb://localhost:27017/?tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd), + }, + { + name: "uri has options", + args: args{ + cfg: Config{Host: "mongodb://localhost:27017/?a=1"}, + caFilePath: ca, + certFilePath: cert, + pwd: pwd, + }, + want: fmt.Sprintf("mongodb://localhost:27017/?a=1&tls=true&tlsCAFile=./%s&tlsCertificateKeyFile=./%s&tlsCertificateKeyFilePassword=%s&authMechanism=MONGODB-X509", ca, cert, pwd), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getURIWithTLS(tt.args.cfg, tt.args.caFilePath, tt.args.certFilePath, tt.args.pwd); got != tt.want { + t.Errorf("getURIWithTLS() = %v, want %v", got, tt.want) + } + }) + } +}