Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 115 additions & 10 deletions integration/client/restart_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@ package client
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"syscall"
"testing"
"time"

. "github.com/containerd/containerd"
"github.com/containerd/containerd/containers"
eventtypes "github.com/containerd/containerd/api/events"
"github.com/containerd/containerd/oci"
"github.com/containerd/containerd/pkg/testutil"
"github.com/containerd/containerd/runtime/restart"
srvconfig "github.com/containerd/containerd/services/server/config"
"github.com/containerd/containerd/sys"
"github.com/containerd/typeurl"
exec "golang.org/x/sys/execabs"
)

Expand Down Expand Up @@ -148,7 +152,7 @@ version = 2
oci.WithImageConfig(image),
longCommand,
),
withRestartStatus(Running),
restart.WithStatus(Running),
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -229,14 +233,115 @@ version = 2
t.Logf("%v: the task was restarted since %v", time.Now(), lastCheck)
}

// withRestartStatus is a copy of "github.com/containerd/containerd/runtime/restart".WithStatus.
// This copy is needed because `go test` refuses circular imports.
func withRestartStatus(status ProcessStatus) func(context.Context, *Client, *containers.Container) error {
return func(_ context.Context, _ *Client, c *containers.Container) error {
if c.Labels == nil {
c.Labels = make(map[string]string)
// TestRestartMonitorWithOnFailurePolicy creates a container labelled with the
// "on-failure:1" restart policy whose process exits with status 1, waits for
// the task to exit, then waits for a "/tasks/create" event carrying the same
// container ID, and finally checks that the container's restart count label
// equals 1.
func TestRestartMonitorWithOnFailurePolicy(t *testing.T) {
const (
interval = 5 * time.Second
)
// The daemon is started with the restart plugin's interval set to the
// constant above.
configTOML := fmt.Sprintf(`
version = 2
[plugins]
[plugins."io.containerd.internal.v1.restart"]
interval = "%s"
`, interval.String())
client, _, cleanup := newDaemonWithConfig(t, configTOML)
defer cleanup()

var (
ctx, cancel = testContext(t)
id = t.Name()
)
defer cancel()

image, err := client.Pull(ctx, testImage, WithPullUnpack)
if err != nil {
t.Fatal(err)
}

// NOTE(review): the error returned by restart.NewPolicy is discarded here.
policy, _ := restart.NewPolicy("on-failure:1")
container, err := client.NewContainer(ctx, id,
WithNewSnapshot(id, image),
WithNewSpec(
oci.WithImageConfig(image),
// always exited with 1
withExitStatus(1),
),
restart.WithStatus(Running),
restart.WithPolicy(policy),
)
if err != nil {
t.Fatal(err)
}
// Cleanup failures are only logged so they do not mask the test result.
defer func() {
if err := container.Delete(ctx, WithSnapshotCleanup); err != nil {
t.Logf("failed to delete container: %v", err)
}
// NOTE(review): the next two lines reference `c` and `status`, which are not
// declared in this function; they look like leftovers of the removed
// withRestartStatus helper in this diff view — confirm against the real file.
c.Labels["containerd.io/restart.status"] = string(status)
return nil
}()

task, err := container.NewTask(ctx, empty())
if err != nil {
t.Fatal(err)
}
defer func() {
if _, err := task.Delete(ctx, WithProcessKill); err != nil {
t.Logf("failed to delete task: %v", err)
}
}()

if err := task.Start(ctx); err != nil {
t.Fatal(err)
}

statusCh, err := task.Wait(ctx)
if err != nil {
t.Fatal(err)
}

// The subscription is opened before blocking on the exit status below;
// presumably so that the create event for the restarted task is not missed
// — TODO confirm.
eventCh, eventErrCh := client.Subscribe(ctx, `topic=="/tasks/create"`)

// Wait for the first task to exit, failing after 30 seconds.
select {
case <-statusCh:
case <-time.After(30 * time.Second):
t.Fatal("should receive exit event in time")
}

// Wait up to one minute for a task-create event and verify that it refers
// to this test's container.
select {
case e := <-eventCh:
cid, err := convertTaskCreateEvent(e.Event)
if err != nil {
t.Fatal(err)
}
if cid != id {
t.Fatalf("expected task id = %s, but got %s", id, cid)
}
case err := <-eventErrCh:
t.Fatalf("unexpected error from event channel: %v", err)
case <-time.After(1 * time.Minute):
t.Fatal("should receive create event in time")
}

labels, err := container.Labels(ctx)
if err != nil {
t.Fatal(err)
}
// The Atoi error is ignored: a missing or non-numeric label yields 0, which
// fails the comparison below.
restartCount, _ := strconv.Atoi(labels[restart.CountLabel])
if restartCount != 1 {
t.Fatalf("expected restart count to be 1, got %d", restartCount)
}
}

// convertTaskCreateEvent decodes the given event payload and returns the
// container ID it carries.
//
// It returns an error when the payload cannot be unmarshalled or when the
// decoded value is not a *eventtypes.TaskCreate.
func convertTaskCreateEvent(e typeurl.Any) (string, error) {
	decoded, err := typeurl.UnmarshalAny(e)
	if err != nil {
		return "", fmt.Errorf("failed to unmarshalany: %w", err)
	}

	// Only task-create events are understood; anything else is rejected.
	created, ok := decoded.(*eventtypes.TaskCreate)
	if !ok {
		return "", errors.New("unsupported event")
	}
	return created.ContainerID, nil
}
12 changes: 12 additions & 0 deletions runtime/restart/monitor/change.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"context"
"fmt"
"net/url"
"strconv"
"syscall"

"github.com/containerd/containerd"
"github.com/containerd/containerd/cio"
"github.com/containerd/containerd/runtime/restart"
"github.com/sirupsen/logrus"
)

Expand All @@ -38,6 +40,7 @@ func (s *stopChange) apply(ctx context.Context, client *containerd.Client) error
type startChange struct {
container containerd.Container
logURI string
count int

// Deprecated(in release 1.5): but recognized now, prefer to use logURI
logPath string
Expand All @@ -61,6 +64,15 @@ func (s *startChange) apply(ctx context.Context, client *containerd.Client) erro
s.logPath, s.logURI)
}

if s.count > 0 {
labels := map[string]string{
restart.CountLabel: strconv.Itoa(s.count),
}
opt := containerd.WithAdditionalContainerLabels(labels)
if err := s.container.Update(ctx, containerd.UpdateContainerOpts(opt)); err != nil {
return err
}
}
killTask(ctx, s.container)
task, err := s.container.NewTask(ctx, log)
if err != nil {
Expand Down
30 changes: 17 additions & 13 deletions runtime/restart/monitor/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package monitor
import (
"context"
"fmt"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -72,6 +73,7 @@ func init() {
},
},
InitFn: func(ic *plugin.InitContext) (interface{}, error) {
ic.Meta.Capabilities = []string{"no", "always", "on-failure", "unless-stopped"}
opts, err := getServicesOpts(ic)
if err != nil {
return nil, err
Expand Down Expand Up @@ -213,15 +215,29 @@ func (m *monitor) monitor(ctx context.Context) ([]change, error) {
return nil, err
}
desiredStatus := containerd.ProcessStatus(labels[restart.StatusLabel])
if m.isSameStatus(ctx, desiredStatus, c) {
task, err := c.Task(ctx, nil)
if err != nil && desiredStatus == containerd.Stopped {
continue
}
status, err := task.Status(ctx)
if err != nil && desiredStatus == containerd.Stopped {
continue
}
if desiredStatus == status.Status {
continue
}

switch desiredStatus {
case containerd.Running:
if !restart.Reconcile(status, labels) {
continue
}
restartCount, _ := strconv.Atoi(labels[restart.CountLabel])
changes = append(changes, &startChange{
container: c,
logPath: labels[restart.LogPathLabel],
logURI: labels[restart.LogURILabel],
count: restartCount + 1,
})
case containerd.Stopped:
changes = append(changes, &stopChange{
Expand All @@ -231,15 +247,3 @@ func (m *monitor) monitor(ctx context.Context) ([]change, error) {
}
return changes, nil
}

// isSameStatus reports whether the container's current task status equals the
// desired status.
//
// When the task or its status cannot be fetched, the result is true only if
// the desired status is containerd.Stopped.
func (m *monitor) isSameStatus(ctx context.Context, desired containerd.ProcessStatus, container containerd.Container) bool {
	// Shared answer for both lookup-failure paths below.
	wantStopped := desired == containerd.Stopped

	t, taskErr := container.Task(ctx, nil)
	if taskErr != nil {
		return wantStopped
	}

	if current, statusErr := t.Status(ctx); statusErr == nil {
		return current.Status == desired
	}
	return wantStopped
}
Loading