Skip to content

Commit 16d52de

Browse files
junnplusAkihiroSuda
authored andcommitted
refactor: reduce duplicate code
Signed-off-by: Ye Sijun <[email protected]> (cherry picked from commit 1ab42be) Signed-off-by: Akihiro Suda <[email protected]>
1 parent b45e302 commit 16d52de

File tree

2 files changed

+136
-36
lines changed

2 files changed

+136
-36
lines changed

oci/spec_opts.go

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -629,11 +629,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts {
629629
func WithUserID(uid uint32) SpecOpts {
630630
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
631631
setProcess(s)
632-
if c.Snapshotter == "" && c.SnapshotKey == "" {
633-
if !isRootfsAbs(s.Root.Path) {
634-
return errors.New("rootfs absolute path is required")
635-
}
636-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
632+
setUser := func(root string) error {
633+
user, err := UserFromPath(root, func(u user.User) bool {
637634
return u.Uid == int(uid)
638635
})
639636
if err != nil {
@@ -645,7 +642,12 @@ func WithUserID(uid uint32) SpecOpts {
645642
}
646643
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
647644
return nil
648-
645+
}
646+
if c.Snapshotter == "" && c.SnapshotKey == "" {
647+
if !isRootfsAbs(s.Root.Path) {
648+
return errors.New("rootfs absolute path is required")
649+
}
650+
return setUser(s.Root.Path)
649651
}
650652
if c.Snapshotter == "" {
651653
return errors.New("no snapshotter set for container")
@@ -660,20 +662,7 @@ func WithUserID(uid uint32) SpecOpts {
660662
}
661663

662664
mounts = tryReadonlyMounts(mounts)
663-
return mount.WithTempMount(ctx, mounts, func(root string) error {
664-
user, err := UserFromPath(root, func(u user.User) bool {
665-
return u.Uid == int(uid)
666-
})
667-
if err != nil {
668-
if os.IsNotExist(err) || err == ErrNoUsersFound {
669-
s.Process.User.UID, s.Process.User.GID = uid, 0
670-
return nil
671-
}
672-
return err
673-
}
674-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
675-
return nil
676-
})
665+
return mount.WithTempMount(ctx, mounts, setUser)
677666
}
678667
}
679668

@@ -687,11 +676,8 @@ func WithUsername(username string) SpecOpts {
687676
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
688677
setProcess(s)
689678
if s.Linux != nil {
690-
if c.Snapshotter == "" && c.SnapshotKey == "" {
691-
if !isRootfsAbs(s.Root.Path) {
692-
return errors.New("rootfs absolute path is required")
693-
}
694-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
679+
setUser := func(root string) error {
680+
user, err := UserFromPath(root, func(u user.User) bool {
695681
return u.Name == username
696682
})
697683
if err != nil {
@@ -700,6 +686,12 @@ func WithUsername(username string) SpecOpts {
700686
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
701687
return nil
702688
}
689+
if c.Snapshotter == "" && c.SnapshotKey == "" {
690+
if !isRootfsAbs(s.Root.Path) {
691+
return errors.New("rootfs absolute path is required")
692+
}
693+
return setUser(s.Root.Path)
694+
}
703695
if c.Snapshotter == "" {
704696
return errors.New("no snapshotter set for container")
705697
}
@@ -713,16 +705,7 @@ func WithUsername(username string) SpecOpts {
713705
}
714706

715707
mounts = tryReadonlyMounts(mounts)
716-
return mount.WithTempMount(ctx, mounts, func(root string) error {
717-
user, err := UserFromPath(root, func(u user.User) bool {
718-
return u.Name == username
719-
})
720-
if err != nil {
721-
return err
722-
}
723-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
724-
return nil
725-
})
708+
return mount.WithTempMount(ctx, mounts, setUser)
726709
} else if s.Windows != nil {
727710
s.Process.User.Username = username
728711
} else {

oci/spec_opts_linux_test.go

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package oci
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"os"
2223
"path/filepath"
2324
"testing"
@@ -30,6 +31,123 @@ import (
3031
"golang.org/x/sys/unix"
3132
)
3233

34+
// nolint:gosec
35+
func TestWithUserID(t *testing.T) {
36+
t.Parallel()
37+
38+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
39+
guest:x:405:100:guest:/dev/null:/sbin/nologin
40+
`
41+
td := t.TempDir()
42+
apply := fstest.Apply(
43+
fstest.CreateDir("/etc", 0777),
44+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
45+
)
46+
if err := apply.Apply(td); err != nil {
47+
t.Fatalf("failed to apply: %v", err)
48+
}
49+
c := containers.Container{ID: t.Name()}
50+
testCases := []struct {
51+
userID uint32
52+
expectedUID uint32
53+
expectedGID uint32
54+
}{
55+
{
56+
userID: 0,
57+
expectedUID: 0,
58+
expectedGID: 0,
59+
},
60+
{
61+
userID: 405,
62+
expectedUID: 405,
63+
expectedGID: 100,
64+
},
65+
{
66+
userID: 1000,
67+
expectedUID: 1000,
68+
expectedGID: 0,
69+
},
70+
}
71+
for _, testCase := range testCases {
72+
t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) {
73+
t.Parallel()
74+
s := Spec{
75+
Version: specs.Version,
76+
Root: &specs.Root{
77+
Path: td,
78+
},
79+
Linux: &specs.Linux{},
80+
}
81+
err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s)
82+
assert.NoError(t, err)
83+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
84+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
85+
})
86+
}
87+
}
88+
89+
// nolint:gosec
90+
func TestWithUsername(t *testing.T) {
91+
t.Parallel()
92+
93+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
94+
guest:x:405:100:guest:/dev/null:/sbin/nologin
95+
`
96+
td := t.TempDir()
97+
apply := fstest.Apply(
98+
fstest.CreateDir("/etc", 0777),
99+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
100+
)
101+
if err := apply.Apply(td); err != nil {
102+
t.Fatalf("failed to apply: %v", err)
103+
}
104+
c := containers.Container{ID: t.Name()}
105+
testCases := []struct {
106+
user string
107+
expectedUID uint32
108+
expectedGID uint32
109+
err string
110+
}{
111+
{
112+
user: "root",
113+
expectedUID: 0,
114+
expectedGID: 0,
115+
},
116+
{
117+
user: "guest",
118+
expectedUID: 405,
119+
expectedGID: 100,
120+
},
121+
{
122+
user: "1000",
123+
err: "no users found",
124+
},
125+
{
126+
user: "unknown",
127+
err: "no users found",
128+
},
129+
}
130+
for _, testCase := range testCases {
131+
t.Run(testCase.user, func(t *testing.T) {
132+
t.Parallel()
133+
s := Spec{
134+
Version: specs.Version,
135+
Root: &specs.Root{
136+
Path: td,
137+
},
138+
Linux: &specs.Linux{},
139+
}
140+
err := WithUsername(testCase.user)(context.Background(), nil, &c, &s)
141+
if err != nil {
142+
assert.EqualError(t, err, testCase.err)
143+
}
144+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
145+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
146+
})
147+
}
148+
149+
}
150+
33151
// nolint:gosec
34152
func TestWithAdditionalGIDs(t *testing.T) {
35153
t.Parallel()
@@ -54,7 +172,6 @@ sys:x:3:root,bin,adm
54172
c := containers.Container{ID: t.Name()}
55173

56174
testCases := []struct {
57-
name string
58175
user string
59176
expected []uint32
60177
}{

0 commit comments

Comments
 (0)