diff --git a/pipeline/backend/kubernetes/kubernetes.go b/pipeline/backend/kubernetes/kubernetes.go index 49448cede..3bcd79623 100644 --- a/pipeline/backend/kubernetes/kubernetes.go +++ b/pipeline/backend/kubernetes/kubernetes.go @@ -25,6 +25,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" backoff "github.com/cenkalti/backoff/v5" @@ -287,7 +288,8 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string) log.Trace().Str("taskUUID", taskUUID).Msgf("waiting for pod: %s", podName) - finished := make(chan bool) + finished := make(chan struct{}) + var finishedOnce sync.Once podUpdated := func(_, newPod any) { pod, ok := newPod.(*v1.Pod) @@ -298,12 +300,12 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string) if pod.Name == podName { if isImagePullBackOffState(pod) || isInvalidImageName(pod) { - finished <- true + finishedOnce.Do(func() { close(finished) }) } switch pod.Status.Phase { case v1.PodSucceeded, v1.PodFailed, v1.PodUnknown: - finished <- true + finishedOnce.Do(func() { close(finished) }) } } } @@ -321,8 +323,11 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string) si.Start(stop) defer close(stop) - // TODO: Cancel on ctx.Done - <-finished + select { + case <-finished: + case <-ctx.Done(): + return nil, ctx.Err() + } pod, err := e.client.CoreV1().Pods(e.config.GetNamespace(step.OrgID)).Get(ctx, podName, meta_v1.GetOptions{}) if err != nil { @@ -363,7 +368,8 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string) log.Trace().Str("taskUUID", taskUUID).Msgf("tail logs of pod: %s", podName) - up := make(chan bool) + up := make(chan struct{}) + var upOnce sync.Once podUpdated := func(_, newPod any) { pod, ok := newPod.(*v1.Pod) @@ -374,11 +380,11 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string) if pod.Name == podName { if isImagePullBackOffState(pod) || isInvalidImageName(pod) { - up <- true + upOnce.Do(func() { close(up) }) } switch pod.Status.Phase { case v1.PodRunning, v1.PodSucceeded, v1.PodFailed: - up <- true + upOnce.Do(func() { close(up) }) } } } @@ -396,7 +402,11 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string) si.Start(stop) defer close(stop) - <-up + select { + case <-up: + case <-ctx.Done(): + return nil, ctx.Err() + } opts := &v1.PodLogOptions{ Follow: true, diff --git a/pipeline/backend/kubernetes/kubernetes_test.go b/pipeline/backend/kubernetes/kubernetes_test.go index e911bd247..2cfc0af1e 100644 --- a/pipeline/backend/kubernetes/kubernetes_test.go +++ b/pipeline/backend/kubernetes/kubernetes_test.go @@ -16,11 +16,16 @@ package kubernetes import ( "context" + "fmt" + "runtime" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/cli/v3" + v1 "k8s.io/api/core/v1" meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/fake" @@ -175,3 +180,125 @@ func TestAffinityFromCliContext(t *testing.T) { err := cmd.Run(context.Background(), []string{"test"}) require.NoError(t, err) } + +func makeStep(uuid string) *types.Step { + return &types.Step{ + UUID: uuid, + Name: "step-" + uuid, + OrgID: 1, + } +} + +func makeEngine(client *fake.Clientset) *kube { + return &kube{ + client: client, + config: &config{ + Namespace: "test-ns", + }, + } +} + +func createPod( + t *testing.T, + client *fake.Clientset, + step *types.Step, + namespace string, +) string { + t.Helper() + podName, err := stepToPodName(step) + require.NoError(t, err) + + pod := &v1.Pod{ + ObjectMeta: meta_v1.ObjectMeta{ + Name: podName, + Namespace: namespace, + }, + Status: v1.PodStatus{ + Phase: v1.PodPending, + }, + } + _, err = client.CoreV1().Pods(namespace).Create( + context.Background(), pod, meta_v1.CreateOptions{}, + ) + require.NoError(t, err) + return podName +} + +func TestWaitStepReturnsOnContextCancel(t *testing.T) { + client := fake.NewClientset() + engine := makeEngine(client) + step := makeStep("ctx-cancel-01") + namespace := "test-ns" + + createPod(t, client, step, namespace) + + ctx, cancel := context.WithCancelCause(context.Background()) + + type result struct { + state *types.State + err error + } + ch := make(chan result, 1) + + go func() { + s, err := engine.WaitStep(ctx, step, "task-1") + ch <- result{s, err} + }() + + // Give the informer time to start and begin watching. + time.Sleep(200 * time.Millisecond) + + cancel(nil) + + select { + case r := <-ch: + assert.Nil(t, r.state) + assert.ErrorIs(t, r.err, context.Canceled) + case <-time.After(3 * time.Second): + t.Fatal("WaitStep did not return after context cancellation") + } +} + +func TestWaitStepNoGoroutineLeak(t *testing.T) { + client := fake.NewClientset() + engine := makeEngine(client) + namespace := "test-ns" + numSteps := 10 + + steps := make([]*types.Step, numSteps) + for i := range numSteps { + steps[i] = makeStep(fmt.Sprintf("leak-%02d", i)) + createPod(t, client, steps[i], namespace) + } + + runtime.GC() + time.Sleep(100 * time.Millisecond) + baselineGoroutines := runtime.NumGoroutine() + + var wg sync.WaitGroup + for i := range numSteps { + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithCancelCause(context.Background()) + + go func() { + _, _ = engine.WaitStep(ctx, steps[i], fmt.Sprintf("task-%d", i)) + }() + + time.Sleep(200 * time.Millisecond) + cancel(nil) + }() + } + wg.Wait() + + time.Sleep(1 * time.Second) + + afterCancelGoroutines := runtime.NumGoroutine() + leaked := afterCancelGoroutines - baselineGoroutines + + assert.Less(t, leaked, numSteps, + "goroutines leaked after canceling %d WaitStep calls: got %d leaked", + numSteps, leaked) +} diff --git a/pipeline/runtime/executor.go b/pipeline/runtime/executor.go index edf1b0d19..9c1f94ad6 100644 --- a/pipeline/runtime/executor.go +++ b/pipeline/runtime/executor.go @@ -261,6 +261,9 @@ func (r *Runtime) exec(runnerCtx context.Context, step *backend.Step, setupWg *s waitState, err := r.engine.WaitStep(r.ctx, step, r.taskUUID) //nolint:contextcheck if err != nil { if errors.Is(err, context.Canceled) { + if waitState == nil { + waitState = &backend.State{} + } waitState.Error = pipeline_errors.ErrCancel } else { return nil, err