Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/controller/machine/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func AddWithActuatorOpts(mgr manager.Manager, actuator Actuator, opts controller
}

if err := addWithOpts(mgr, controller.Options{
MaxConcurrentReconciles: opts.MaxConcurrentReconciles,
Reconciler: newDrainController(mgr),
RateLimiter: newDrainRateLimiter(),
}, "machine-drain-controller"); err != nil {
Expand Down
42 changes: 34 additions & 8 deletions pkg/controller/machine/drain_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package machine
import (
"context"
"fmt"
"sync"
"time"

"golang.org/x/time/rate"
Expand Down Expand Up @@ -38,6 +39,13 @@ type machineDrainController struct {
scheme *runtime.Scheme

eventRecorder record.EventRecorder

// controlPlaneDrainLock ensures only one control plane node can pass the
// isDrainAllowed check and be cordoned at a time. Without this, concurrent
// reconciles of different control plane machines can both pass
// isDrainAllowed before either cordons, leading to simultaneous control
// plane node drains (TOCTOU race).
controlPlaneDrainLock sync.Mutex
}

// newDrainController returns a new reconcile.Reconciler for machine-drain-controller
Expand Down Expand Up @@ -135,10 +143,6 @@ func (d *machineDrainController) drainNode(ctx context.Context, machine *machine
return fmt.Errorf("unable to get node %q: %v", machine.Status.NodeRef.Name, err)
}

if err := d.isDrainAllowed(ctx, node); err != nil {
return fmt.Errorf("drain not permitted: %w", err)
}

drainer := &drain.Helper{
Ctx: ctx,
Client: kubeClient,
Expand Down Expand Up @@ -170,10 +174,8 @@ func (d *machineDrainController) drainNode(ctx context.Context, machine *machine
drainer.GracePeriodSeconds = 1
}

if err := drain.RunCordonOrUncordon(drainer, node, true); err != nil {
// Can't cordon a node
klog.Warningf("cordon failed for node %q: %v", node.Name, err)
return &RequeueAfterError{RequeueAfter: 20 * time.Second}
if err := d.cordonNode(ctx, drainer, node); err != nil {
return err
}

if err := drain.RunNodeDrain(drainer, node.Name); err != nil {
Expand All @@ -193,6 +195,30 @@ func (d *machineDrainController) drainNode(ctx context.Context, machine *machine
return nil
}

// cordonNode checks whether draining is allowed and then cordons the node.
//
// For control plane nodes that are not yet cordoned, the permission check and
// the cordon are performed while holding controlPlaneDrainLock. This keeps
// the check-then-act sequence atomic across concurrent reconciles: without
// it, two reconciles for different control plane machines could each observe
// that no other control plane node is cordoned and then both proceed to
// drain at the same time (a TOCTOU race).
func (d *machineDrainController) cordonNode(ctx context.Context, drainer *drain.Helper, node *corev1.Node) error {
	// An already-cordoned node is resuming a previous drain attempt and does
	// not need to be serialized against other control plane drains.
	if mustSerialize := isControlPlaneNode(*node) && !node.Spec.Unschedulable; mustSerialize {
		d.controlPlaneDrainLock.Lock()
		defer d.controlPlaneDrainLock.Unlock()
	}

	if err := d.isDrainAllowed(ctx, node); err != nil {
		return fmt.Errorf("drain not permitted: %w", err)
	}

	err := drain.RunCordonOrUncordon(drainer, node, true)
	if err == nil {
		return nil
	}

	// Cordon failed; log and ask for a retry later rather than erroring out.
	klog.Warningf("cordon failed for node %q: %v", node.Name, err)
	return &RequeueAfterError{RequeueAfter: 20 * time.Second}
}

// isDrainAllowed checks whether the drain is permitted at this time.
// It checks the following:
// - Is the node cordoned, if so allow draining to complete any previous attempt to drain.
Expand Down
92 changes: 92 additions & 0 deletions pkg/controller/machine/drain_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package machine

import (
"context"
"sync"
"testing"
"time"

Expand All @@ -11,8 +12,11 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
kubefake "k8s.io/client-go/kubernetes/fake"
"k8s.io/client-go/kubernetes/scheme"
clienttesting "k8s.io/client-go/testing"
"k8s.io/client-go/tools/record"
"k8s.io/kubectl/pkg/drain"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
Expand Down Expand Up @@ -345,3 +349,91 @@ func controlPlaneLabel(n *corev1.Node) {
// masterLabel adds the legacy "master" node-role label, with an empty value,
// to the given node's label map.
func masterLabel(n *corev1.Node) {
	labels := n.GetLabels()
	labels[nodeMasterLabel] = ""
}

// TestCordonNodeSerializesCPDrains verifies that concurrent cordonNode calls
// for two different control plane nodes are serialized by
// controlPlaneDrainLock: exactly one passes isDrainAllowed and cordons its
// node, and the other is denied with a "drain not permitted" error.
func TestCordonNodeSerializesCPDrains(t *testing.T) {
	g := NewGomegaWithT(t)

	cpNode1 := newNode("cp-node-1", controlPlaneLabel)
	cpNode2 := newNode("cp-node-2", controlPlaneLabel)
	workerNode := newNode("worker-1")

	// controller-runtime fake client for isDrainAllowed (List)
	crClient := fake.NewClientBuilder().
		WithScheme(scheme.Scheme).
		WithRuntimeObjects(cpNode1, cpNode2, workerNode).
		Build()

	// kube fake clientset for drain.RunCordonOrUncordon (node update)
	kubeClient := kubefake.NewSimpleClientset(cpNode1.DeepCopy(), cpNode2.DeepCopy(), workerNode.DeepCopy())

	// Bridge: when the kube client cordons a node via patch (sets Unschedulable=true),
	// also update the controller-runtime fake client so isDrainAllowed sees it.
	// This simulates the informer cache catching up after a direct API write.
	// RunCordonOrUncordon uses a strategic merge patch, so we intercept "patch" actions.
	kubeClient.PrependReactor("patch", "nodes", func(action clienttesting.Action) (bool, runtime.Object, error) {
		patchAction, ok := action.(clienttesting.PatchAction)
		if !ok {
			return false, nil, nil
		}
		nodeName := patchAction.GetName()

		// Propagate the cordon to the controller-runtime fake client
		existing := &corev1.Node{}
		if err := crClient.Get(context.Background(), types.NamespacedName{Name: nodeName}, existing); err == nil {
			existing.Spec.Unschedulable = true
			if err := crClient.Update(context.Background(), existing); err != nil {
				t.Logf("bridge: failed to update CR client for node %s: %v", nodeName, err)
			}
		}
		return false, nil, nil // fall through to the default handler
	})

	d := &machineDrainController{
		Client: crClient,
	}

	makeDrainer := func() *drain.Helper {
		return &drain.Helper{
			Ctx:                 context.Background(),
			Client:              kubeClient,
			Force:               true,
			IgnoreAllDaemonSets: true,
			DeleteEmptyDirData:  true,
			GracePeriodSeconds:  -1,
			Timeout:             20 * time.Second,
			Out:                 writer{t.Log},
			ErrOut:              writer{t.Log},
		}
	}

	var wg sync.WaitGroup
	errs := make([]error, 2)

	wg.Add(2)
	go func() {
		defer wg.Done()
		errs[0] = d.cordonNode(context.Background(), makeDrainer(), cpNode1)
	}()
	go func() {
		defer wg.Done()
		errs[1] = d.cordonNode(context.Background(), makeDrainer(), cpNode2)
	}()
	wg.Wait()

	// Exactly one goroutine should succeed. The mutex serializes the
	// check+cordon, so the second goroutine re-runs isDrainAllowed after the
	// first node is already cordoned and is denied with a wrapped
	// "drain not permitted" error. (A RequeueAfterError would only appear if
	// the cordon API call itself failed, which cannot happen with the fake
	// clientset here.)
	succeeded := 0
	denied := 0
	for _, err := range errs {
		if err == nil {
			succeeded++
		} else {
			g.Expect(err).To(MatchError(ContainSubstring("drain not permitted")))
			denied++
		}
	}

	g.Expect(succeeded).To(Equal(1), "exactly one CP drain should succeed")
	g.Expect(denied).To(Equal(1), "exactly one CP drain should be denied")
}