Skip to content

Commit df34eef

Browse files
authored
Merge pull request #2330 from crosbymichael/hpc
Add nvidia gpu support
2 parents 00d4910 + b949697 commit df34eef

File tree

6 files changed

+390
-0
lines changed

6 files changed

+390
-0
lines changed

cmd/containerd/command/main.go

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ func App() *cli.App {
8888
app.Commands = []cli.Command{
8989
configCommand,
9090
publishCommand,
91+
ociHook,
9192
}
9293
app.Action = func(context *cli.Context) error {
9394
var (

cmd/containerd/command/oci-hook.go

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
Copyright The containerd Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package command
18+
19+
import (
20+
"bytes"
21+
"encoding/json"
22+
"io"
23+
"os"
24+
"path/filepath"
25+
"syscall"
26+
"text/template"
27+
28+
specs "github.com/opencontainers/runtime-spec/specs-go"
29+
"github.com/urfave/cli"
30+
)
31+
32+
var ociHook = cli.Command{
33+
Name: "oci-hook",
34+
Usage: "provides a base for OCI runtime hooks to allow arguments to be injected.",
35+
Action: func(context *cli.Context) error {
36+
state, err := loadHookState(os.Stdin)
37+
if err != nil {
38+
return err
39+
}
40+
var (
41+
ctx = newTemplateContext(state)
42+
args = []string(context.Args())
43+
env = os.Environ()
44+
)
45+
if err := newList(&args).render(ctx); err != nil {
46+
return err
47+
}
48+
if err := newList(&env).render(ctx); err != nil {
49+
return err
50+
}
51+
return syscall.Exec(args[0], args, env)
52+
},
53+
}
54+
55+
func loadHookState(r io.Reader) (*specs.State, error) {
56+
var s specs.State
57+
if err := json.NewDecoder(r).Decode(&s); err != nil {
58+
return nil, err
59+
}
60+
return &s, nil
61+
}
62+
63+
func newTemplateContext(state *specs.State) *templateContext {
64+
t := &templateContext{
65+
state: state,
66+
}
67+
t.funcs = template.FuncMap{
68+
"id": t.id,
69+
"bundle": t.bundle,
70+
"rootfs": t.rootfs,
71+
"pid": t.pid,
72+
"annotation": t.annotation,
73+
"status": t.status,
74+
}
75+
return t
76+
}
77+
78+
type templateContext struct {
79+
state *specs.State
80+
funcs template.FuncMap
81+
}
82+
83+
func (t *templateContext) id() string {
84+
return t.state.ID
85+
}
86+
87+
func (t *templateContext) bundle() string {
88+
return t.state.Bundle
89+
}
90+
91+
func (t *templateContext) rootfs() string {
92+
return filepath.Join(t.state.Bundle, "rootfs")
93+
}
94+
95+
func (t *templateContext) pid() int {
96+
return t.state.Pid
97+
}
98+
99+
func (t *templateContext) annotation(k string) string {
100+
return t.state.Annotations[k]
101+
}
102+
103+
func (t *templateContext) status() string {
104+
return t.state.Status
105+
}
106+
107+
func render(ctx *templateContext, source string, out io.Writer) error {
108+
t, err := template.New("oci-hook").Funcs(ctx.funcs).Parse(source)
109+
if err != nil {
110+
return err
111+
}
112+
return t.Execute(out, ctx)
113+
}
114+
115+
func newList(l *[]string) *templateList {
116+
return &templateList{
117+
l: l,
118+
}
119+
}
120+
121+
type templateList struct {
122+
l *[]string
123+
}
124+
125+
func (l *templateList) render(ctx *templateContext) error {
126+
buf := bytes.NewBuffer(nil)
127+
for i, s := range *l.l {
128+
buf.Reset()
129+
if err := render(ctx, s, buf); err != nil {
130+
return err
131+
}
132+
(*l.l)[i] = buf.String()
133+
}
134+
buf.Reset()
135+
return nil
136+
}

cmd/ctr/commands/run/run.go

+4
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ var ContainerFlags = []cli.Flag{
9393
Name: "pid-file",
9494
Usage: "file path to write the task's pid",
9595
},
96+
cli.IntFlag{
97+
Name: "gpus",
98+
Usage: "add gpus to the container",
99+
},
96100
}
97101

98102
func loadSpec(path string, s *specs.Spec) error {

cmd/ctr/commands/run/run_unix.go

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424

2525
"github.com/containerd/containerd"
2626
"github.com/containerd/containerd/cmd/ctr/commands"
27+
"github.com/containerd/containerd/contrib/nvidia"
2728
"github.com/containerd/containerd/oci"
2829
specs "github.com/opencontainers/runtime-spec/specs-go"
2930
"github.com/pkg/errors"
@@ -123,6 +124,9 @@ func NewContainer(ctx gocontext.Context, client *containerd.Client, context *cli
123124
Path: parts[1],
124125
}))
125126
}
127+
if context.IsSet("gpus") {
128+
opts = append(opts, nvidia.WithGPUs(nvidia.WithDevices(context.Int("gpus")), nvidia.WithAllCapabilities))
129+
}
126130
if context.IsSet("config") {
127131
var s specs.Spec
128132
if err := loadSpec(context.String("config"), &s); err != nil {

container_test.go

+60
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"io"
2323
"io/ioutil"
2424
"os"
25+
"os/exec"
2526
"runtime"
2627
"strings"
2728
"syscall"
@@ -34,6 +35,7 @@ import (
3435
"github.com/containerd/containerd/oci"
3536
_ "github.com/containerd/containerd/runtime"
3637
"github.com/containerd/typeurl"
38+
specs "github.com/opencontainers/runtime-spec/specs-go"
3739

3840
"github.com/containerd/containerd/errdefs"
3941
"github.com/containerd/containerd/windows/hcsshimtypes"
@@ -1469,3 +1471,61 @@ func TestContainerLabels(t *testing.T) {
14691471
t.Fatalf("expected label \"test\" to be \"no\"")
14701472
}
14711473
}
1474+
1475+
func TestContainerHook(t *testing.T) {
1476+
t.Parallel()
1477+
1478+
client, err := newClient(t, address)
1479+
if err != nil {
1480+
t.Fatal(err)
1481+
}
1482+
defer client.Close()
1483+
1484+
var (
1485+
image Image
1486+
ctx, cancel = testContext()
1487+
id = t.Name()
1488+
)
1489+
defer cancel()
1490+
1491+
image, err = client.GetImage(ctx, testImage)
1492+
if err != nil {
1493+
t.Fatal(err)
1494+
}
1495+
hook := func(_ context.Context, _ oci.Client, _ *containers.Container, s *specs.Spec) error {
1496+
if s.Hooks == nil {
1497+
s.Hooks = &specs.Hooks{}
1498+
}
1499+
path, err := exec.LookPath("containerd")
1500+
if err != nil {
1501+
return err
1502+
}
1503+
psPath, err := exec.LookPath("ps")
1504+
if err != nil {
1505+
return err
1506+
}
1507+
s.Hooks.Prestart = []specs.Hook{
1508+
{
1509+
Path: path,
1510+
Args: []string{
1511+
"containerd",
1512+
"oci-hook", "--",
1513+
psPath, "--pid", "{{pid}}",
1514+
},
1515+
Env: os.Environ(),
1516+
},
1517+
}
1518+
return nil
1519+
}
1520+
container, err := client.NewContainer(ctx, id, WithNewSpec(oci.WithImageConfig(image), hook), WithNewSnapshot(id, image))
1521+
if err != nil {
1522+
t.Fatal(err)
1523+
}
1524+
defer container.Delete(ctx, WithSnapshotCleanup)
1525+
1526+
task, err := container.NewTask(ctx, empty())
1527+
if err != nil {
1528+
t.Fatal(err)
1529+
}
1530+
defer task.Delete(ctx, WithProcessKill)
1531+
}

0 commit comments

Comments
 (0)