Skip to content

Commit 1ab42be

Browse files
committed
refactor: reduce duplicate code
Signed-off-by: Ye Sijun <[email protected]>
1 parent 279f406 commit 1ab42be

2 files changed

Lines changed: 136 additions & 36 deletions

File tree

oci/spec_opts.go

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -701,11 +701,8 @@ func WithUIDGID(uid, gid uint32) SpecOpts {
701701
func WithUserID(uid uint32) SpecOpts {
702702
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
703703
setProcess(s)
704-
if c.Snapshotter == "" && c.SnapshotKey == "" {
705-
if !isRootfsAbs(s.Root.Path) {
706-
return errors.New("rootfs absolute path is required")
707-
}
708-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
704+
setUser := func(root string) error {
705+
user, err := UserFromPath(root, func(u user.User) bool {
709706
return u.Uid == int(uid)
710707
})
711708
if err != nil {
@@ -717,7 +714,12 @@ func WithUserID(uid uint32) SpecOpts {
717714
}
718715
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
719716
return nil
720-
717+
}
718+
if c.Snapshotter == "" && c.SnapshotKey == "" {
719+
if !isRootfsAbs(s.Root.Path) {
720+
return errors.New("rootfs absolute path is required")
721+
}
722+
return setUser(s.Root.Path)
721723
}
722724
if c.Snapshotter == "" {
723725
return errors.New("no snapshotter set for container")
@@ -732,20 +734,7 @@ func WithUserID(uid uint32) SpecOpts {
732734
}
733735

734736
mounts = tryReadonlyMounts(mounts)
735-
return mount.WithTempMount(ctx, mounts, func(root string) error {
736-
user, err := UserFromPath(root, func(u user.User) bool {
737-
return u.Uid == int(uid)
738-
})
739-
if err != nil {
740-
if os.IsNotExist(err) || err == ErrNoUsersFound {
741-
s.Process.User.UID, s.Process.User.GID = uid, 0
742-
return nil
743-
}
744-
return err
745-
}
746-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
747-
return nil
748-
})
737+
return mount.WithTempMount(ctx, mounts, setUser)
749738
}
750739
}
751740

@@ -759,11 +748,8 @@ func WithUsername(username string) SpecOpts {
759748
return func(ctx context.Context, client Client, c *containers.Container, s *Spec) (err error) {
760749
setProcess(s)
761750
if s.Linux != nil {
762-
if c.Snapshotter == "" && c.SnapshotKey == "" {
763-
if !isRootfsAbs(s.Root.Path) {
764-
return errors.New("rootfs absolute path is required")
765-
}
766-
user, err := UserFromPath(s.Root.Path, func(u user.User) bool {
751+
setUser := func(root string) error {
752+
user, err := UserFromPath(root, func(u user.User) bool {
767753
return u.Name == username
768754
})
769755
if err != nil {
@@ -772,6 +758,12 @@ func WithUsername(username string) SpecOpts {
772758
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
773759
return nil
774760
}
761+
if c.Snapshotter == "" && c.SnapshotKey == "" {
762+
if !isRootfsAbs(s.Root.Path) {
763+
return errors.New("rootfs absolute path is required")
764+
}
765+
return setUser(s.Root.Path)
766+
}
775767
if c.Snapshotter == "" {
776768
return errors.New("no snapshotter set for container")
777769
}
@@ -785,16 +777,7 @@ func WithUsername(username string) SpecOpts {
785777
}
786778

787779
mounts = tryReadonlyMounts(mounts)
788-
return mount.WithTempMount(ctx, mounts, func(root string) error {
789-
user, err := UserFromPath(root, func(u user.User) bool {
790-
return u.Name == username
791-
})
792-
if err != nil {
793-
return err
794-
}
795-
s.Process.User.UID, s.Process.User.GID = uint32(user.Uid), uint32(user.Gid)
796-
return nil
797-
})
780+
return mount.WithTempMount(ctx, mounts, setUser)
798781
} else if s.Windows != nil {
799782
s.Process.User.Username = username
800783
} else {

oci/spec_opts_linux_test.go

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package oci
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"os"
2223
"path/filepath"
2324
"testing"
@@ -30,6 +31,123 @@ import (
3031
"golang.org/x/sys/unix"
3132
)
3233

34+
// nolint:gosec
35+
func TestWithUserID(t *testing.T) {
36+
t.Parallel()
37+
38+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
39+
guest:x:405:100:guest:/dev/null:/sbin/nologin
40+
`
41+
td := t.TempDir()
42+
apply := fstest.Apply(
43+
fstest.CreateDir("/etc", 0777),
44+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
45+
)
46+
if err := apply.Apply(td); err != nil {
47+
t.Fatalf("failed to apply: %v", err)
48+
}
49+
c := containers.Container{ID: t.Name()}
50+
testCases := []struct {
51+
userID uint32
52+
expectedUID uint32
53+
expectedGID uint32
54+
}{
55+
{
56+
userID: 0,
57+
expectedUID: 0,
58+
expectedGID: 0,
59+
},
60+
{
61+
userID: 405,
62+
expectedUID: 405,
63+
expectedGID: 100,
64+
},
65+
{
66+
userID: 1000,
67+
expectedUID: 1000,
68+
expectedGID: 0,
69+
},
70+
}
71+
for _, testCase := range testCases {
72+
t.Run(fmt.Sprintf("user %d", testCase.userID), func(t *testing.T) {
73+
t.Parallel()
74+
s := Spec{
75+
Version: specs.Version,
76+
Root: &specs.Root{
77+
Path: td,
78+
},
79+
Linux: &specs.Linux{},
80+
}
81+
err := WithUserID(testCase.userID)(context.Background(), nil, &c, &s)
82+
assert.NoError(t, err)
83+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
84+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
85+
})
86+
}
87+
}
88+
89+
// nolint:gosec
90+
func TestWithUsername(t *testing.T) {
91+
t.Parallel()
92+
93+
expectedPasswd := `root:x:0:0:root:/root:/bin/ash
94+
guest:x:405:100:guest:/dev/null:/sbin/nologin
95+
`
96+
td := t.TempDir()
97+
apply := fstest.Apply(
98+
fstest.CreateDir("/etc", 0777),
99+
fstest.CreateFile("/etc/passwd", []byte(expectedPasswd), 0777),
100+
)
101+
if err := apply.Apply(td); err != nil {
102+
t.Fatalf("failed to apply: %v", err)
103+
}
104+
c := containers.Container{ID: t.Name()}
105+
testCases := []struct {
106+
user string
107+
expectedUID uint32
108+
expectedGID uint32
109+
err string
110+
}{
111+
{
112+
user: "root",
113+
expectedUID: 0,
114+
expectedGID: 0,
115+
},
116+
{
117+
user: "guest",
118+
expectedUID: 405,
119+
expectedGID: 100,
120+
},
121+
{
122+
user: "1000",
123+
err: "no users found",
124+
},
125+
{
126+
user: "unknown",
127+
err: "no users found",
128+
},
129+
}
130+
for _, testCase := range testCases {
131+
t.Run(testCase.user, func(t *testing.T) {
132+
t.Parallel()
133+
s := Spec{
134+
Version: specs.Version,
135+
Root: &specs.Root{
136+
Path: td,
137+
},
138+
Linux: &specs.Linux{},
139+
}
140+
err := WithUsername(testCase.user)(context.Background(), nil, &c, &s)
141+
if err != nil {
142+
assert.EqualError(t, err, testCase.err)
143+
}
144+
assert.Equal(t, testCase.expectedUID, s.Process.User.UID)
145+
assert.Equal(t, testCase.expectedGID, s.Process.User.GID)
146+
})
147+
}
148+
149+
}
150+
33151
// nolint:gosec
34152
func TestWithAdditionalGIDs(t *testing.T) {
35153
t.Parallel()
@@ -54,7 +172,6 @@ sys:x:3:root,bin,adm
54172
c := containers.Container{ID: t.Name()}
55173

56174
testCases := []struct {
57-
name string
58175
user string
59176
expected []uint32
60177
}{

0 commit comments

Comments
 (0)