Merge branch 'origin/main' into 'next-release/main'

This commit is contained in:
oauth
2026-03-02 20:16:55 +00:00
3 changed files with 149 additions and 9 deletions

View File

@@ -25,6 +25,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
backoff "github.com/cenkalti/backoff/v5"
@@ -287,7 +288,8 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string)
log.Trace().Str("taskUUID", taskUUID).Msgf("waiting for pod: %s", podName)
finished := make(chan bool)
finished := make(chan struct{})
var finishedOnce sync.Once
podUpdated := func(_, newPod any) {
pod, ok := newPod.(*v1.Pod)
@@ -298,12 +300,12 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string)
if pod.Name == podName {
if isImagePullBackOffState(pod) || isInvalidImageName(pod) {
finished <- true
finishedOnce.Do(func() { close(finished) })
}
switch pod.Status.Phase {
case v1.PodSucceeded, v1.PodFailed, v1.PodUnknown:
finished <- true
finishedOnce.Do(func() { close(finished) })
}
}
}
@@ -321,8 +323,11 @@ func (e *kube) WaitStep(ctx context.Context, step *types.Step, taskUUID string)
si.Start(stop)
defer close(stop)
// TODO: Cancel on ctx.Done
<-finished
select {
case <-finished:
case <-ctx.Done():
return nil, ctx.Err()
}
pod, err := e.client.CoreV1().Pods(e.config.GetNamespace(step.OrgID)).Get(ctx, podName, meta_v1.GetOptions{})
if err != nil {
@@ -363,7 +368,8 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string)
log.Trace().Str("taskUUID", taskUUID).Msgf("tail logs of pod: %s", podName)
up := make(chan bool)
up := make(chan struct{})
var upOnce sync.Once
podUpdated := func(_, newPod any) {
pod, ok := newPod.(*v1.Pod)
@@ -374,11 +380,11 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string)
if pod.Name == podName {
if isImagePullBackOffState(pod) || isInvalidImageName(pod) {
up <- true
upOnce.Do(func() { close(up) })
}
switch pod.Status.Phase {
case v1.PodRunning, v1.PodSucceeded, v1.PodFailed:
up <- true
upOnce.Do(func() { close(up) })
}
}
}
@@ -396,7 +402,11 @@ func (e *kube) TailStep(ctx context.Context, step *types.Step, taskUUID string)
si.Start(stop)
defer close(stop)
<-up
select {
case <-up:
case <-ctx.Done():
return nil, ctx.Err()
}
opts := &v1.PodLogOptions{
Follow: true,

View File

@@ -16,11 +16,16 @@ package kubernetes
import (
"context"
"fmt"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v3"
v1 "k8s.io/api/core/v1"
meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"
@@ -175,3 +180,125 @@ func TestAffinityFromCliContext(t *testing.T) {
err := cmd.Run(context.Background(), []string{"test"})
require.NoError(t, err)
}
// makeStep builds a minimal step for tests. The given uuid becomes the step
// UUID and is also used to derive the name ("step-" + uuid); OrgID is always 1.
func makeStep(uuid string) *types.Step {
	step := new(types.Step)
	step.UUID = uuid
	step.Name = "step-" + uuid
	step.OrgID = 1
	return step
}
// makeEngine returns a kube engine backed by the given fake clientset and a
// config whose Namespace is "test-ns".
func makeEngine(client *fake.Clientset) *kube {
	cfg := &config{Namespace: "test-ns"}
	return &kube{client: client, config: cfg}
}
// createPod creates a pod for step in namespace on the fake clientset and
// returns the pod's name. The name comes from stepToPodName and the pod is
// created in the Pending phase. Any error aborts the calling test.
func createPod(
	t *testing.T,
	client *fake.Clientset,
	step *types.Step,
	namespace string,
) string {
	t.Helper()

	name, err := stepToPodName(step)
	require.NoError(t, err)

	meta := meta_v1.ObjectMeta{Name: name, Namespace: namespace}
	status := v1.PodStatus{Phase: v1.PodPending}
	pods := client.CoreV1().Pods(namespace)

	_, err = pods.Create(
		context.Background(),
		&v1.Pod{ObjectMeta: meta, Status: status},
		meta_v1.CreateOptions{},
	)
	require.NoError(t, err)

	return name
}
// TestWaitStepReturnsOnContextCancel checks that WaitStep stops waiting when
// its context is canceled: it must return a nil state and an error matching
// context.Canceled within three seconds of the cancellation.
func TestWaitStepReturnsOnContextCancel(t *testing.T) {
	const namespace = "test-ns"

	client := fake.NewClientset()
	engine := makeEngine(client)
	step := makeStep("ctx-cancel-01")
	createPod(t, client, step, namespace)

	ctx, cancel := context.WithCancelCause(context.Background())

	type outcome struct {
		state *types.State
		err   error
	}
	// Capacity 1 lets the goroutine deliver its result and exit even when the
	// test has already given up waiting.
	done := make(chan outcome, 1)
	go func() {
		state, err := engine.WaitStep(ctx, step, "task-1")
		done <- outcome{state: state, err: err}
	}()

	// Let the informer start watching before the context is canceled.
	time.Sleep(200 * time.Millisecond)
	cancel(nil)

	timeout := time.After(3 * time.Second)
	select {
	case <-timeout:
		t.Fatal("WaitStep did not return after context cancellation")
	case got := <-done:
		assert.Nil(t, got.state)
		assert.ErrorIs(t, got.err, context.Canceled)
	}
}
// TestWaitStepNoGoroutineLeak starts several concurrent WaitStep calls on
// pods that stay Pending, cancels each call's context, and then verifies two
// things:
//
//  1. every WaitStep call actually returns after its context is canceled, and
//  2. the process goroutine count does not grow by numSteps or more compared
//     to the baseline taken before the calls were started.
//
// The goroutine-count bound is deliberately loose because unrelated runtime
// goroutines may come and go while the test runs.
func TestWaitStepNoGoroutineLeak(t *testing.T) {
	client := fake.NewClientset()
	engine := makeEngine(client)
	namespace := "test-ns"

	const numSteps = 10
	steps := make([]*types.Step, numSteps)
	for i := range steps {
		steps[i] = makeStep(fmt.Sprintf("leak-%02d", i))
		createPod(t, client, steps[i], namespace)
	}

	// Settle the runtime before sampling the baseline goroutine count.
	runtime.GC()
	time.Sleep(100 * time.Millisecond)
	baselineGoroutines := runtime.NumGoroutine()

	var wg sync.WaitGroup
	for i, step := range steps {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithCancelCause(context.Background())

			// returned is closed once WaitStep comes back, so the worker can
			// wait for the call itself instead of only for the cancel.
			returned := make(chan struct{})
			go func() {
				defer close(returned)
				_, _ = engine.WaitStep(ctx, step, fmt.Sprintf("task-%d", i))
			}()

			// Let the informer start watching before canceling.
			time.Sleep(200 * time.Millisecond)
			cancel(nil)

			select {
			case <-returned:
			case <-time.After(3 * time.Second):
				// Errorf (not Fatalf): this is not the test goroutine.
				t.Errorf("WaitStep for task-%d did not return after context cancellation", i)
			}
		}()
	}
	wg.Wait()

	// Give background goroutines started by WaitStep time to wind down.
	time.Sleep(1 * time.Second)

	afterCancelGoroutines := runtime.NumGoroutine()
	leaked := afterCancelGoroutines - baselineGoroutines
	assert.Less(t, leaked, numSteps,
		"goroutines leaked after canceling %d WaitStep calls: got %d leaked",
		numSteps, leaked)
}

View File

@@ -261,6 +261,9 @@ func (r *Runtime) exec(runnerCtx context.Context, step *backend.Step, setupWg *s
waitState, err := r.engine.WaitStep(r.ctx, step, r.taskUUID) //nolint:contextcheck
if err != nil {
if errors.Is(err, context.Canceled) {
if waitState == nil {
waitState = &backend.State{}
}
waitState.Error = pipeline_errors.ErrCancel
} else {
return nil, err