From ce7604eab29b71af4ad5b158f18bbdcea2f35333 Mon Sep 17 00:00:00 2001 From: Markus Wennrich Date: Wed, 29 May 2024 15:38:43 +0200 Subject: [PATCH] use only the nguid part of the volumeHandle for lightbits volumes (#33) --- pkg/metal/core.go | 40 +++++++++++++++++++---- pkg/metal/core_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 pkg/metal/core_test.go diff --git a/pkg/metal/core.go b/pkg/metal/core.go index 4767ec0..3a61059 100644 --- a/pkg/metal/core.go +++ b/pkg/metal/core.go @@ -349,21 +349,47 @@ func (p *Provider) ListMachines(ctx context.Context, req *driver.ListMachinesReq // // RESPONSE PARAMETERS (driver.GetVolumeIDsResponse) // VolumeIDs []string VolumeIDs is a repeated list of VolumeIDs. -func (p *Provider) GetVolumeIDs(ctx context.Context, req *driver.GetVolumeIDsRequest) (*driver.GetVolumeIDsResponse, error) { +func (p *Provider) GetVolumeIDs(_ context.Context, req *driver.GetVolumeIDsRequest) (*driver.GetVolumeIDsResponse, error) { // Log messages to track start and end of request klog.V(2).Infof("GetVolumeIDs request has been received for %q", req.PVSpecs) - volumeIDs := []string{} - specs := req.PVSpecs - for i := range specs { - spec := specs[i] - if spec.CSI == nil { + + var ( + volumeIDs []string + ) + + for _, spec := range req.PVSpecs { + if spec == nil || spec.CSI == nil { // Not a CSI volume continue } - volumeIDs = append(volumeIDs, spec.CSI.VolumeHandle) + switch spec.CSI.Driver { + case "csi.lightbitslabs.com": + fields := map[string]string{} + for _, part := range strings.Split(spec.CSI.VolumeHandle, "|") { + k, v, ok := strings.Cut(part, ":") + if !ok { + continue + } + fields[k] = v + } + + nguid, ok := fields["nguid"] + if ok { + volumeIDs = append(volumeIDs, nguid) + continue + } + + klog.Errorf("invalid lightbits volumeHandle (missing nguid): %s", spec.CSI.VolumeHandle) + + fallthrough + default: + volumeIDs = append(volumeIDs, spec.CSI.VolumeHandle) + } } + klog.V(2).Infof("GetVolumeIDs request has been processed successfully for %q", req.PVSpecs) + return &driver.GetVolumeIDsResponse{VolumeIDs: volumeIDs}, nil } diff --git a/pkg/metal/core_test.go b/pkg/metal/core_test.go new file mode 100644 index 0000000..3ff159b --- /dev/null +++ b/pkg/metal/core_test.go @@ -0,0 +1,73 @@ +// Package provider contains the cloud provider specific implementations to manage machines +package provider + +import ( + "context" + "testing" + + "github.com/gardener/machine-controller-manager/pkg/util/provider/driver" + "github.com/google/go-cmp/cmp" + corev1 "k8s.io/api/core/v1" +) + +func TestProvider_GetVolumeIDs(t *testing.T) { + tests := []struct { + name string + req *driver.GetVolumeIDsRequest + want *driver.GetVolumeIDsResponse + wantErr error + }{ + { + name: "valid lightbits volume", + req: &driver.GetVolumeIDsRequest{ + PVSpecs: []*corev1.PersistentVolumeSpec{ + { + + PersistentVolumeSource: corev1.PersistentVolumeSource{ + CSI: &corev1.CSIPersistentVolumeSource{ + Driver: "csi.lightbitslabs.com", + VolumeHandle: "mgmt:10.131.44.1:443,10.131.44.2:443,10.131.44.3:443|nguid:d22572da-a225-4578-ab1a-9318ac5155c3|proj:cd4eac58-46a5-4a31-b59f-2ec207baa817|scheme:grpcs", + }, + }, + }, + }, + }, + want: &driver.GetVolumeIDsResponse{ + VolumeIDs: []string{"d22572da-a225-4578-ab1a-9318ac5155c3"}, + }, + }, + { + name: "invalid lightbits volume", + req: &driver.GetVolumeIDsRequest{ + PVSpecs: []*corev1.PersistentVolumeSpec{ + { + + PersistentVolumeSource: corev1.PersistentVolumeSource{ + CSI: &corev1.CSIPersistentVolumeSource{ + Driver: "csi.lightbitslabs.com", + VolumeHandle: "mgmt:10.131.44.1:443,10.131.44.2:443,10.131.44.3:443|proj:cd4eac58-46a5-4a31-b59f-2ec207baa817|scheme:grpcs", + }, + }, + }, + }, + }, + want: &driver.GetVolumeIDsResponse{ + VolumeIDs: []string{"mgmt:10.131.44.1:443,10.131.44.2:443,10.131.44.3:443|proj:cd4eac58-46a5-4a31-b59f-2ec207baa817|scheme:grpcs"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provider{} + + got, err := p.GetVolumeIDs(context.Background(), tt.req) + + if diff := cmp.Diff(tt.wantErr, err); diff != "" { + t.Errorf("err diff = %s", diff) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("diff = %s", diff) + } + }) + } +}