From 3f39cbd5063ab7971ac6d5815a583ad092e9651e Mon Sep 17 00:00:00 2001 From: andyzhangx Date: Sat, 18 Jan 2025 13:09:41 +0000 Subject: [PATCH] fix: support base64password field in secret fix fix fix fix --- Makefile | 2 +- pkg/smb/nodeserver.go | 13 ++++++++++++- pkg/smb/nodeserver_test.go | 17 +++++++++++++++++ pkg/smb/smb.go | 11 +++++++++++ pkg/smb/smb_test.go | 28 ++++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 033d1e161a7..23eecfd75af 100644 --- a/Makefile +++ b/Makefile @@ -246,7 +246,7 @@ endif .PHONY: install-smb-provisioner install-smb-provisioner: kubectl delete secret smbcreds --ignore-not-found -n default - kubectl create secret generic smbcreds --from-literal username=USERNAME --from-literal password="PASSWORD" --from-literal mountOptions="dir_mode=0777,file_mode=0777,uid=0,gid=0,mfsymlinks" -n default + kubectl create secret generic smbcreds --from-literal username=USERNAME --from-literal password="PASSWORD" --from-literal base64password="UEFTU1dPUkQ=" --from-literal mountOptions="dir_mode=0777,file_mode=0777,uid=0,gid=0,mfsymlinks" -n default ifdef TEST_WINDOWS kubectl apply -f deploy/example/smb-provisioner/smb-server-lb.yaml else diff --git a/pkg/smb/nodeserver.go b/pkg/smb/nodeserver.go index 0528003ea01..96f78637cc4 100644 --- a/pkg/smb/nodeserver.go +++ b/pkg/smb/nodeserver.go @@ -183,7 +183,7 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe } defer d.volumeLocks.Release(lockKey) - var username, password, domain string + var username, password, base64Password, domain string for k, v := range secrets { switch strings.ToLower(k) { case usernameField: @@ -192,9 +192,20 @@ func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRe password = strings.TrimSpace(v) case domainField: domain = strings.TrimSpace(v) + case base64PasswordField: + base64Password = strings.TrimSpace(v) } } + if base64Password != "" { + klog.V(2).Infof("NodeStageVolume: decoding password from base64 string") + decodePassword, err := base64.StdEncoding.DecodeString(base64Password) + if err != nil { + return nil, status.Error(codes.InvalidArgument, "error base64 decoding password") + } + password = string(decodePassword) + } + if ephemeralVol { mountFlags = strings.Split(ephemeralVolMountOptions, ",") } diff --git a/pkg/smb/nodeserver_test.go b/pkg/smb/nodeserver_test.go index 4b87fe1a987..ea914196f8f 100644 --- a/pkg/smb/nodeserver_test.go +++ b/pkg/smb/nodeserver_test.go @@ -92,6 +92,11 @@ func TestNodeStageVolume(t *testing.T) { passwordField: "test_password", domainField: "test_doamin", } + secretsWithBase64Password := map[string]string{ + usernameField: "test_username", + passwordField: base64.StdEncoding.EncodeToString([]byte("test_password")), + domainField: "test_doamin", + } tests := []struct { desc string @@ -230,6 +235,18 @@ func TestNodeStageVolume(t *testing.T) { strings.Replace(testSource, "\\", "\\\\", -1), sourceTest, testSource, sourceTest), expectedErr: testutil.TestError{}, }, + { + desc: "[Success] Valid request with base64 encoded password", + req: &csi.NodeStageVolumeRequest{VolumeId: "vol_1##", StagingTargetPath: sourceTest, + VolumeCapability: &stdVolCap, + VolumeContext: volContext, + Secrets: secretsWithBase64Password}, + skipOnWindows: true, + flakyWindowsErrorMessage: fmt.Sprintf("rpc error: code = Internal desc = volume(vol_1##) mount \"%s\" on %#v failed with "+ + "NewSmbGlobalMapping(%s, %s) failed with error: rpc error: code = Unknown desc = NewSmbGlobalMapping failed.", + strings.Replace(testSource, "\\", "\\\\", -1), sourceTest, testSource, sourceTest), + expectedErr: testutil.TestError{}, + }, } // Setup diff --git a/pkg/smb/smb.go b/pkg/smb/smb.go index d0f2365708b..1dcfe83d434 100644 --- a/pkg/smb/smb.go +++ b/pkg/smb/smb.go @@ -18,6 +18,7 @@ package smb import ( "context" + "encoding/base64" "errors" "fmt" "net" @@ -49,6 +50,7 @@ const ( sourceField = "source" subDirField = "subdir" domainField = "domain" + base64PasswordField = "base64password" mountOptionsField = "mountoptions" secretNameField = "secretname" secretNamespaceField = "secretnamespace" @@ -232,6 +234,15 @@ func (d *Driver) GetUserNamePasswordFromSecret(ctx context.Context, secretName, username := strings.TrimSpace(string(secret.Data[usernameField][:])) password := strings.TrimSpace(string(secret.Data[passwordField][:])) domain := strings.TrimSpace(string(secret.Data[domainField][:])) + base64Password := strings.TrimSpace(string(secret.Data[base64PasswordField][:])) + if base64Password != "" { + klog.V(2).Infof("decoding password from base64 string") + decodePassword, err := base64.StdEncoding.DecodeString(base64Password) + if err != nil { + return "", "", "", fmt.Errorf("could not decode password from base64 string: %v", err) + } + password = string(decodePassword) + } return username, password, domain, nil } diff --git a/pkg/smb/smb_test.go b/pkg/smb/smb_test.go index 21761cbbd49..ff01686dbd7 100644 --- a/pkg/smb/smb_test.go +++ b/pkg/smb/smb_test.go @@ -17,6 +17,7 @@ limitations under the License. package smb import ( + "context" "fmt" "os" "path/filepath" @@ -520,6 +521,33 @@ users: } } +func TestGetUserNamePasswordFromSecret(t *testing.T) { + tests := []struct { + desc string + secretName string + secretNamespace string + expectedUsername string + expectedPassword string + expectedDomain string + expectedError error + }{ + { + desc: "kubeclient is nil", + secretName: "secretName", + expectedError: fmt.Errorf("could not username and password from secret(secretName): KubeClient is nil"), + }, + } + + d := NewFakeDriver() + for _, test := range tests { + username, password, domain, err := d.GetUserNamePasswordFromSecret(context.Background(), test.secretName, test.secretNamespace) + assert.Equal(t, test.expectedUsername, username, "test[%s]: unexpected username", test.desc) + assert.Equal(t, test.expectedPassword, password, "test[%s]: unexpected password", test.desc) + assert.Equal(t, test.expectedDomain, domain, "test[%s]: unexpected domain", test.desc) + assert.Equal(t, test.expectedError, err, "test[%s]: unexpected error", test.desc) + } +} + func createTestFile(path string) error { f, err := os.Create(path) if err != nil {