Skip to content

Commit 1b2426d

Browse files
authored
Merge pull request #133 from carlory/patch-002
Add GetGroupControllerCapabilities common function
2 parents a0d716c + 9ea2545 commit 1b2426d

File tree

2 files changed

+152
-5
lines changed

2 files changed

+152
-5
lines changed

rpc/common.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,33 @@ func GetControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (Cont
104104
return caps, nil
105105
}
106106

107+
// GroupControllerCapabilitySet is set of CSI groupcontroller capabilities. Only supported capabilities are in the map.
108+
type GroupControllerCapabilitySet map[csi.GroupControllerServiceCapability_RPC_Type]bool
109+
110+
// GetGroupControllerCapabilities returns set of supported group controller capabilities of CSI driver.
111+
func GetGroupControllerCapabilities(ctx context.Context, conn *grpc.ClientConn) (GroupControllerCapabilitySet, error) {
112+
client := csi.NewGroupControllerClient(conn)
113+
req := csi.GroupControllerGetCapabilitiesRequest{}
114+
rsp, err := client.GroupControllerGetCapabilities(ctx, &req)
115+
if err != nil {
116+
return nil, err
117+
}
118+
119+
caps := GroupControllerCapabilitySet{}
120+
for _, cap := range rsp.GetCapabilities() {
121+
if cap == nil {
122+
continue
123+
}
124+
rpc := cap.GetRpc()
125+
if rpc == nil {
126+
continue
127+
}
128+
t := rpc.GetType()
129+
caps[t] = true
130+
}
131+
return caps, nil
132+
}
133+
107134
// ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready.
108135
// Any error other than timeout is returned.
109136
func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error {

rpc/common_test.go

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const (
5252
// startServer creates a gRPC server without any registered services.
5353
// The returned address can be used to connect to it. The cleanup
5454
// function stops it. It can be called multiple times.
55-
func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer) (string, func()) {
55+
func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controller csi.ControllerServer, groupCtrl csi.GroupControllerServer) (string, func()) {
5656
addr := path.Join(tmp, serverSock)
5757
listener, err := net.Listen("unix", addr)
5858
require.NoError(t, err, "listening on %s", addr)
@@ -63,6 +63,9 @@ func startServer(t *testing.T, tmp string, identity csi.IdentityServer, controll
6363
if controller != nil {
6464
csi.RegisterControllerServer(server, controller)
6565
}
66+
if groupCtrl != nil {
67+
csi.RegisterGroupControllerServer(server, groupCtrl)
68+
}
6669
var wg sync.WaitGroup
6770
wg.Add(1)
6871
go func() {
@@ -127,7 +130,7 @@ func TestGetDriverName(t *testing.T) {
127130
pluginInfoResponse: out,
128131
err: injectedErr,
129132
}
130-
addr, stopServer := startServer(t, tmp, identity, nil)
133+
addr, stopServer := startServer(t, tmp, identity, nil, nil)
131134
defer func() {
132135
stopServer()
133136
}()
@@ -247,7 +250,7 @@ func TestGetPluginCapabilities(t *testing.T) {
247250
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
248251
err: injectedErr,
249252
}
250-
addr, stopServer := startServer(t, tmp, identity, nil)
253+
addr, stopServer := startServer(t, tmp, identity, nil, nil)
251254
defer func() {
252255
stopServer()
253256
}()
@@ -375,7 +378,7 @@ func TestGetControllerCapabilities(t *testing.T) {
375378
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
376379
err: injectedErr,
377380
}
378-
addr, stopServer := startServer(t, tmp, nil, controller)
381+
addr, stopServer := startServer(t, tmp, nil, controller, nil)
379382
defer func() {
380383
stopServer()
381384
}()
@@ -399,6 +402,100 @@ func TestGetControllerCapabilities(t *testing.T) {
399402
}
400403
}
401404

405+
func TestGetGroupControllerCapabilities(t *testing.T) {
406+
tests := []struct {
407+
name string
408+
output *csi.GroupControllerGetCapabilitiesResponse
409+
injectError bool
410+
expectCapabilities GroupControllerCapabilitySet
411+
expectError bool
412+
}{
413+
{
414+
name: "success",
415+
output: &csi.GroupControllerGetCapabilitiesResponse{
416+
Capabilities: []*csi.GroupControllerServiceCapability{
417+
{
418+
Type: &csi.GroupControllerServiceCapability_Rpc{
419+
Rpc: &csi.GroupControllerServiceCapability_RPC{
420+
Type: csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT,
421+
},
422+
},
423+
},
424+
},
425+
},
426+
expectCapabilities: GroupControllerCapabilitySet{
427+
csi.GroupControllerServiceCapability_RPC_CREATE_DELETE_GET_VOLUME_GROUP_SNAPSHOT: true,
428+
},
429+
expectError: false,
430+
},
431+
{
432+
name: "gRPC error",
433+
output: nil,
434+
injectError: true,
435+
expectError: true,
436+
},
437+
{
438+
name: "empty capability",
439+
output: &csi.GroupControllerGetCapabilitiesResponse{
440+
Capabilities: []*csi.GroupControllerServiceCapability{
441+
{
442+
Type: nil,
443+
},
444+
},
445+
},
446+
expectCapabilities: GroupControllerCapabilitySet{},
447+
expectError: false,
448+
},
449+
{
450+
name: "no capabilities",
451+
output: &csi.GroupControllerGetCapabilitiesResponse{
452+
Capabilities: []*csi.GroupControllerServiceCapability{},
453+
},
454+
expectCapabilities: GroupControllerCapabilitySet{},
455+
expectError: false,
456+
},
457+
}
458+
459+
for _, test := range tests {
460+
t.Run(test.name, func(t *testing.T) {
461+
var injectedErr error
462+
if test.injectError {
463+
injectedErr = fmt.Errorf("mock error")
464+
}
465+
466+
tmp := tmpDir(t)
467+
defer os.RemoveAll(tmp)
468+
groupCtrl := &fakeGroupControllerServer{
469+
groupControllerGetCapabilitiesResponse: test.output,
470+
471+
// Make code compatible with gofmt 1.10.2 (used by pull-sig-storage-csi-lib-utils-stable)
472+
// and 1.11.1 (which will be used by new Prow job) via an extra blank line.
473+
err: injectedErr,
474+
}
475+
addr, stopServer := startServer(t, tmp, nil, nil, groupCtrl)
476+
defer func() {
477+
stopServer()
478+
}()
479+
480+
conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
481+
if err != nil {
482+
t.Fatalf("Failed to connect to CSI driver: %s", err)
483+
}
484+
485+
caps, err := GetGroupControllerCapabilities(context.Background(), conn)
486+
if test.expectError && err == nil {
487+
t.Errorf("Expected error, got none")
488+
}
489+
if !test.expectError && err != nil {
490+
t.Errorf("Got error: %v", err)
491+
}
492+
if !reflect.DeepEqual(test.expectCapabilities, caps) {
493+
t.Errorf("expected capabilities %+v, got %+v", test.expectCapabilities, caps)
494+
}
495+
})
496+
}
497+
}
498+
402499
func TestProbeForever(t *testing.T) {
403500
tests := []struct {
404501
name string
@@ -509,7 +606,7 @@ func TestProbeForever(t *testing.T) {
509606
identity := &fakeIdentityServer{
510607
probeCalls: test.probeCalls,
511608
}
512-
addr, stopServer := startServer(t, tmp, identity, nil)
609+
addr, stopServer := startServer(t, tmp, identity, nil, nil)
513610
defer func() {
514611
stopServer()
515612
}()
@@ -624,3 +721,26 @@ func (c *fakeControllerServer) ListSnapshots(context.Context, *csi.ListSnapshots
624721
func (c *fakeControllerServer) ControllerExpandVolume(context.Context, *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
625722
return nil, fmt.Errorf("unimplemented")
626723
}
724+
725+
type fakeGroupControllerServer struct {
726+
groupControllerGetCapabilitiesResponse *csi.GroupControllerGetCapabilitiesResponse
727+
err error
728+
}
729+
730+
var _ csi.GroupControllerServer = &fakeGroupControllerServer{}
731+
732+
func (c *fakeGroupControllerServer) GroupControllerGetCapabilities(context.Context, *csi.GroupControllerGetCapabilitiesRequest) (*csi.GroupControllerGetCapabilitiesResponse, error) {
733+
return c.groupControllerGetCapabilitiesResponse, c.err
734+
}
735+
736+
func (c *fakeGroupControllerServer) CreateVolumeGroupSnapshot(context.Context, *csi.CreateVolumeGroupSnapshotRequest) (*csi.CreateVolumeGroupSnapshotResponse, error) {
737+
return nil, fmt.Errorf("unimplemented")
738+
}
739+
740+
func (c *fakeGroupControllerServer) DeleteVolumeGroupSnapshot(context.Context, *csi.DeleteVolumeGroupSnapshotRequest) (*csi.DeleteVolumeGroupSnapshotResponse, error) {
741+
return nil, fmt.Errorf("unimplemented")
742+
}
743+
744+
func (c *fakeGroupControllerServer) GetVolumeGroupSnapshot(context.Context, *csi.GetVolumeGroupSnapshotRequest) (*csi.GetVolumeGroupSnapshotResponse, error) {
745+
return nil, fmt.Errorf("unimplemented")
746+
}

0 commit comments

Comments
 (0)