From 9ea254532a4dbd325221c546684804abdf6a8cbe Mon Sep 17 00:00:00 2001 From: carlory Date: Tue, 13 Jun 2023 00:29:07 +0800 Subject: [PATCH] Add GetGroupControllerCapabilities common function --- rpc/common.go | 27 ++++++++++ rpc/common_test.go | 130 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 152 insertions(+), 5 deletions(-) diff --git a/rpc/common.go b/rpc/common.go index 9dcb3534..263bfba3 100644 --- a/rpc/common.go +++ b/rpc/common.go @@ -104,6 +104,33 @@ func GetControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (Cont return caps, nil } +// GroupControllerCapabilitySet is set of CSI groupcontroller capabilities. Only supported capabilities are in the map. +type GroupControllerCapabilitySet map[csi.GroupControllerServiceCapability_RPC_Type]bool + +// GetGroupControllerCapabilities returns set of supported group controller capabilities of CSI driver. +func GetGroupControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (GroupControllerCapabilitySet, error) { + client := csi.NewGroupControllerClient(conn) + req := csi.GroupControllerGetCapabilitiesRequest{} + rsp, err := client.GroupControllerGetCapabilities(ctx, &req) + if err != nil { + return nil, err + } + + caps := GroupControllerCapabilitySet{} + for _, cap := range rsp.GetCapabilities() { + if cap == nil { + continue + } + rpc := cap.GetRpc() + if rpc == nil { + continue + } + t := rpc.GetType() + caps[t] = true + } + return caps, nil +} + // ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready. // Any error other than timeout is returned. func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error { diff --git a/rpc/common_test.go b/rpc/common_test.go index ebba02ef..cac0f118 100644 --- a/rpc/common_test.go +++ b/rpc/common_test.go @@ -52,7 +52,7 @@ const ( // startServer creates a gRPC server without any registered services. // The returned address can be used to connect to it. The cleanup // function stops it. It can be called multiple times. -func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer) (string, func()) { +func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer, groupCtrl csi.GroupControllerServer) (string, func()) { addr := path.Join(tmp, serverSock) listener, err := net.Listen("unix", addr) require.NoError(t, err, "listening on %s", addr) @@ -63,6 +63,9 @@ func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controll if controller != nil { csi.RegisterControllerServer(server, controller) } + if groupCtrl != nil { + csi.RegisterGroupControllerServer(server, groupCtrl) + } var wg sync.WaitGroup wg.Add(1) go func() { @@ -127,7 +130,7 @@ func TestGetDriverName(t *testing.T) { pluginInfoResponse: out, err: injectedErr, } - addr, stopServer := startServer(t, tmp, identity, nil) + addr, stopServer := startServer(t, tmp, identity, nil, nil) defer func() { stopServer() }() @@ -247,7 +250,7 @@ func TestGetPluginCapabilities(t *testing.T) { // and 1.11.1 (which will be used by new Prow job) via an extra blank line. err: injectedErr, } - addr, stopServer := startServer(t, tmp, identity, nil) + addr, stopServer := startServer(t, tmp, identity, nil, nil) defer func() { stopServer() }() @@ -375,7 +378,7 @@ func TestGetControllerCapabilities(t *testing.T) { // and 1.11.1 (which will be used by new Prow job) via an extra blank line. err: injectedErr, } - addr, stopServer := startServer(t, tmp, nil, controller) + addr, stopServer := startServer(t, tmp, nil, controller, nil) defer func() { stopServer() }() @@ -399,6 +402,100 @@ func TestGetControllerCapabilities(t *testing.T) { } } +func TestGetGroupControllerCapabilities(t *testing.T) { + tests := []struct { + name string + output *csi.GroupControllerGetCapabilitiesResponse + injectError bool + expectCapabilities GroupControllerCapabilitySet + expectError bool + }{ + { + name: "success", + output: &csi.GroupControllerGetCapabilitiesResponse{ + Capabilities: []*csi.GroupControllerServiceCapability{ + { + Type: &csi.GroupControllerServiceCapability_Rpc{ + Rpc: &csi.GroupControllerServiceCapability_RPC{ + Type: csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT, + }, + }, + }, + }, + }, + expectCapabilities: GroupControllerCapabilitySet{ + csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT: true, + }, + expectError: false, + }, + { + name: "gRPC error", + output: nil, + injectError: true, + expectError: true, + }, + { + name: "empty capability", + output: &csi.GroupControllerGetCapabilitiesResponse{ + Capabilities: []*csi.GroupControllerServiceCapability{ + { + Type: nil, + }, + }, + }, + expectCapabilities: GroupControllerCapabilitySet{}, + expectError: false, + }, + { + name: "no capabilities", + output: &csi.GroupControllerGetCapabilitiesResponse{ + Capabilities: []*csi.GroupControllerServiceCapability{}, + }, + expectCapabilities: GroupControllerCapabilitySet{}, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var injectedErr error + if test.injectError { + injectedErr = fmt.Errorf("mock error") + } + + tmp := tmpDir(t) + defer os.RemoveAll(tmp) + groupCtrl := &fakeGroupControllerServer{ + groupControllerGetCapabilitiesResponse: test.output, + + // Make code compatible with gofmt 1.10.2 (used by pull-sig-storage-csi-lib-utils-stable) + // and 1.11.1 (which will be used by new Prow job) via an extra blank line. + err: injectedErr, + } + addr, stopServer := startServer(t, tmp, nil, nil, groupCtrl) + defer func() { + stopServer() + }() + + conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io")) + if err != nil { + t.Fatalf("Failed to connect to CSI driver: %s", err) + } + + caps, err := GetGroupControllerCapabilities(context.Background(), conn) + if test.expectError && err == nil { + t.Errorf("Expected error, got none") + } + if !test.expectError && err != nil { + t.Errorf("Got error: %v", err) + } + if !reflect.DeepEqual(test.expectCapabilities, caps) { + t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps) + } + }) + } +} + func TestProbeForever(t *testing.T) { tests := []struct { name string @@ -509,7 +606,7 @@ func TestProbeForever(t *testing.T) { identity := &fakeIdentityServer{ probeCalls: test.probeCalls, } - addr, stopServer := startServer(t, tmp, identity, nil) + addr, stopServer := startServer(t, tmp, identity, nil, nil) defer func() { stopServer() }() @@ -624,3 +721,26 @@ func (c *fakeControllerServer) ListSnapshots(context.Context, *csi.ListSnapshots func (c *fakeControllerServer) ControllerExpandVolume(context.Context, *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) { return nil, fmt.Errorf("unimplemented") } + +type fakeGroupControllerServer struct { + groupControllerGetCapabilitiesResponse *csi.GroupControllerGetCapabilitiesResponse + err error +} + +var _ csi.GroupControllerServer = &fakeGroupControllerServer{} + +func (c *fakeGroupControllerServer) GroupControllerGetCapabilities(context.Context, *csi.GroupControllerGetCapabilitiesRequest) (*csi.GroupControllerGetCapabilitiesResponse, error) { + return c.groupControllerGetCapabilitiesResponse, c.err +} + +func (c *fakeGroupControllerServer) CreateVolumeGroupSnapshot(context.Context, *csi.CreateVolumeGroupSnapshotRequest) (*csi.CreateVolumeGroupSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeGroupControllerServer) DeleteVolumeGroupSnapshot(context.Context, *csi.DeleteVolumeGroupSnapshotRequest) (*csi.DeleteVolumeGroupSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +} + +func (c *fakeGroupControllerServer) GetVolumeGroupSnapshot(context.Context, *csi.GetVolumeGroupSnapshotRequest) (*csi.GetVolumeGroupSnapshotResponse, error) { + return nil, fmt.Errorf("unimplemented") +}