Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions container_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ func TestContainerRuntimeOptionsv2(t *testing.T) {
}
}

func TestContainerKillInitPidHost(t *testing.T) {
func initContainerAndCheckChildrenDieOnKill(t *testing.T, opts ...oci.SpecOpts) {
client, err := newClient(t, address)
if err != nil {
t.Fatal(err)
Expand All @@ -1059,12 +1059,12 @@ func TestContainerKillInitPidHost(t *testing.T) {
t.Fatal(err)
}

opts = append(opts, oci.WithImageConfig(image))
opts = append(opts, withProcessArgs("sh", "-c", "sleep 42; echo hi"))

container, err := client.NewContainer(ctx, id,
WithNewSnapshot(id, image),
WithNewSpec(oci.WithImageConfig(image),
withProcessArgs("sh", "-c", "sleep 42; echo hi"),
oci.WithHostNamespace(specs.PIDNamespace),
),
WithNewSpec(opts...),
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1111,6 +1111,14 @@ func TestContainerKillInitPidHost(t *testing.T) {
}
}

func TestContainerKillInitPidHost(t *testing.T) {
initContainerAndCheckChildrenDieOnKill(t, oci.WithHostNamespace(specs.PIDNamespace))
}

func TestContainerKillInitKillsChildWhenNotHostPid(t *testing.T) {
initContainerAndCheckChildrenDieOnKill(t)
}

func TestUserNamespaces(t *testing.T) {
t.Parallel()
t.Run("WritableRootFS", func(t *testing.T) { testUserNamespaces(t, false) })
Expand Down
1 change: 0 additions & 1 deletion runtime/v1/linux/proc/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ func (p *Init) setExited(status int) {
}

func (p *Init) delete(context context.Context) error {
p.KillAll(context)
p.wg.Wait()
err := p.runtime.Delete(context, p.id, nil)
// ignore errors if a runtime has already deleted the process
Expand Down
41 changes: 36 additions & 5 deletions runtime/v1/shim/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package shim

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
Expand All @@ -41,6 +43,7 @@ import (
runc "github.com/containerd/go-runc"
"github.com/containerd/typeurl"
ptypes "github.com/gogo/protobuf/types"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -507,13 +510,22 @@ func (s *Service) processExits() {
func (s *Service) checkProcesses(e runc.Exit) {
s.mu.Lock()
defer s.mu.Unlock()

shouldKillAll, err := shouldKillAllOnExit(s.bundle)
if err != nil {
log.G(s.context).WithError(err).Error("failed to check shouldKillAll")
}

for _, p := range s.processes {
if p.Pid() == e.Pid {
if ip, ok := p.(*proc.Init); ok {
// Ensure all children are killed
if err := ip.KillAll(s.context); err != nil {
log.G(s.context).WithError(err).WithField("id", ip.ID()).
Error("failed to kill init's children")

if shouldKillAll {
if ip, ok := p.(*proc.Init); ok {
// Ensure all children are killed
if err := ip.KillAll(s.context); err != nil {
log.G(s.context).WithError(err).WithField("id", ip.ID()).
Error("failed to kill init's children")
}
}
}
p.SetExited(e.Status)
Expand All @@ -529,6 +541,25 @@ func (s *Service) checkProcesses(e runc.Exit) {
}
}

func shouldKillAllOnExit(bundlePath string) (bool, error) {
var bundleSpec specs.Spec
bundleConfigContents, err := ioutil.ReadFile(filepath.Join(bundlePath, "config.json"))
if err != nil {
return false, err
}
json.Unmarshal(bundleConfigContents, &bundleSpec)

if bundleSpec.Linux != nil {
for _, ns := range bundleSpec.Linux.Namespaces {
if ns.Type == specs.PIDNamespace {
return false, nil
}
}
}

return true, nil
}

func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
39 changes: 34 additions & 5 deletions runtime/v2/runc/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package runc

import (
"context"
"encoding/json"
"io/ioutil"
"os"
"os/exec"
Expand All @@ -34,6 +35,7 @@ import (
"github.com/containerd/containerd/api/types/task"
"github.com/containerd/containerd/errdefs"
"github.com/containerd/containerd/events"
"github.com/containerd/containerd/log"
"github.com/containerd/containerd/mount"
"github.com/containerd/containerd/namespaces"
"github.com/containerd/containerd/runtime"
Expand All @@ -45,6 +47,7 @@ import (
runcC "github.com/containerd/go-runc"
"github.com/containerd/typeurl"
ptypes "github.com/gogo/protobuf/types"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -638,13 +641,20 @@ func (s *service) processExits() {
}

func (s *service) checkProcesses(e runcC.Exit) {
shouldKillAll, err := shouldKillAllOnExit(s.bundle)
if err != nil {
log.G(s.context).WithError(err).Error("failed to check shouldKillAll")
}

for _, p := range s.allProcesses() {
if p.Pid() == e.Pid {
if ip, ok := p.(*proc.Init); ok {
// Ensure all children are killed
if err := ip.KillAll(s.context); err != nil {
logrus.WithError(err).WithField("id", ip.ID()).
Error("failed to kill init's children")
if shouldKillAll {
if ip, ok := p.(*proc.Init); ok {
// Ensure all children are killed
if err := ip.KillAll(s.context); err != nil {
logrus.WithError(err).WithField("id", ip.ID()).
Error("failed to kill init's children")
}
}
}
p.SetExited(e.Status)
Expand All @@ -660,6 +670,25 @@ func (s *service) checkProcesses(e runcC.Exit) {
}
}

func shouldKillAllOnExit(bundlePath string) (bool, error) {
var bundleSpec specs.Spec
bundleConfigContents, err := ioutil.ReadFile(filepath.Join(bundlePath, "config.json"))
if err != nil {
return false, err
}
json.Unmarshal(bundleConfigContents, &bundleSpec)

if bundleSpec.Linux != nil {
for _, ns := range bundleSpec.Linux.Namespaces {
if ns.Type == specs.PIDNamespace {
return false, nil
}
}
}

return true, nil
}

func (s *service) allProcesses() (o []rproc.Process) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down