diff --git a/.gitignore b/.gitignore index 05b6621..df923dc 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,6 @@ build/ tmp/ temp/ *.tmp + +# local test certificates +cert/ diff --git a/internal/server/server.go b/internal/server/server.go index aa73137..d92d62c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,12 +7,10 @@ import ( "time" "github.com/google/uuid" + "github.com/mateusmlo/taskqueue/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/mateusmlo/taskqueue/internal/worker" - "github.com/mateusmlo/taskqueue/proto" ) type Priority int @@ -39,7 +37,7 @@ type Server struct { pendingQueues map[Priority][]*Task queuesMux sync.RWMutex - workers map[string]*worker.Worker + workers map[string]*WorkerInfo workersMux sync.RWMutex ctx context.Context @@ -73,7 +71,7 @@ func NewServer() *Server { return &Server{ tasks: make(map[string]*Task), pendingQueues: make(map[Priority][]*Task), - workers: make(map[string]*worker.Worker), + workers: make(map[string]*WorkerInfo), ctx: ctx, cancel: cancel, } @@ -158,7 +156,7 @@ func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequ // RegisterWorker handles worker registration requests func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRequest) (*proto.RegisterWorkerResponse, error) { - var newWorker worker.Worker + var newWorker WorkerInfo newWorker.FromProtoWorker(req.Worker) s.workersMux.Lock() @@ -300,7 +298,7 @@ func (s *Server) findTask(taskID string) (*Task, error) { } // findWorker retrieves a worker by its ID, returning an error if not found -func (s *Server) findWorker(workerID string) (*worker.Worker, error) { +func (s *Server) findWorker(workerID string) (*WorkerInfo, error) { s.workersMux.RLock() defer s.workersMux.RUnlock() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 5162258..c292c1d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/mateusmlo/taskqueue/internal/worker" "github.com/mateusmlo/taskqueue/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -1247,7 +1246,7 @@ func TestServer_UtilityFunctions(t *testing.T) { // Create a worker workerID := "test-worker-1" - s.workers[workerID] = &worker.Worker{ + s.workers[workerID] = &WorkerInfo{ ID: workerID, Capacity: 10, CurrentLoad: 5, @@ -1273,7 +1272,7 @@ func TestServer_UtilityFunctions(t *testing.T) { // Create a worker workerID := "test-worker-2" - s.workers[workerID] = &worker.Worker{ + s.workers[workerID] = &WorkerInfo{ ID: workerID, Capacity: 10, CurrentLoad: 5, @@ -1328,7 +1327,7 @@ func TestServer_UtilityFunctions(t *testing.T) { s := NewServer() defer s.cancel() - testWorker := &worker.Worker{ID: "test-worker-1"} + testWorker := &WorkerInfo{ID: "test-worker-1"} s.workers["test-worker-1"] = testWorker // Test finding existing worker diff --git a/internal/server/worker_info.go b/internal/server/worker_info.go new file mode 100644 index 0000000..dc4ec59 --- /dev/null +++ b/internal/server/worker_info.go @@ -0,0 +1,40 @@ +package server + +import ( + "time" + + "github.com/google/uuid" + "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// WorkerInfo tracks a registered worker (server-side only) +type WorkerInfo struct { + ID string + Address string + TaskTypes []string + Capacity int + CurrentLoad int + RegisteredAt time.Time + LastHeartbeat time.Time + Metadata map[string]string +} + +func (wi *WorkerInfo) FromProtoWorker(pw *proto.Worker) error { + uuid, err := uuid.NewV7() + if err != nil { + return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err) + } + + wi.ID = uuid.String() + wi.TaskTypes = pw.TaskTypes + wi.Address = pw.Metadata["address"] + wi.Capacity = int(pw.Capacity) + wi.CurrentLoad = 0 + wi.Metadata = pw.Metadata + wi.RegisteredAt = time.Now() + wi.LastHeartbeat = time.Now() + + return nil +} diff --git a/internal/worker/worker.go b/internal/worker/worker.go index efd7667..b1f0fbe 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -1,40 +1,239 @@ package worker import ( + "context" + "log" + "sync" "time" - "github.com/google/uuid" "github.com/mateusmlo/taskqueue/proto" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) type Worker struct { - ID string - Address string - RegisteredAt time.Time - LastHeartbeat time.Time - TaskTypes []string - Capacity int - CurrentLoad int - Metadata map[string]string -} - -// FromProtoWorker initializes a Worker instance from a proto.Worker message (server generates ID) -func (w *Worker) FromProtoWorker(pw *proto.Worker) error { - uuid, err := uuid.NewV7() + serverAddr string + conn *grpc.ClientConn + client proto.WorkerServiceClient + + id string + capacity int + + handlers map[string]TaskHandler + currentLoad int + loadMux sync.RWMutex + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +type TaskHandler interface { + Handle(ctx context.Context, payload []byte) ([]byte, error) +} + +func NewWorker(serverAddr string, capacity int) *Worker { + ctx, cancel := context.WithCancel(context.Background()) + + return &Worker{ + serverAddr: serverAddr, + capacity: capacity, + handlers: make(map[string]TaskHandler), + ctx: ctx, + cancel: cancel, + } +} + +func (w *Worker) RegisterHandler(taskType string, handler TaskHandler) { + w.handlers[taskType] = handler +} + +func (w *Worker) Start() error { + tcr, err := credentials.NewClientTLSFromFile("./cert/server.crt", "localhost") if err != nil { - return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err) + return err + } + + conn, err := grpc.NewClient(w.serverAddr, grpc.WithTransportCredentials(tcr)) + if err != nil { + return err + } + + w.conn = conn + w.client = proto.NewWorkerServiceClient(w.conn) + + if err := w.register(); err != nil { + w.conn.Close() + return err + } + + w.wg.Add(2) + go w.heartbeatLoop() + go w.fetchLoop() + + return nil +} + +func (w *Worker) Stop() { + w.cancel() + w.wg.Wait() + + if w.conn != nil { + if err := w.conn.Close(); err != nil { + log.Printf("Error closing gRPC connection: %v", err) + } + + w.conn = nil + } +} + +func (w *Worker) heartbeatLoop() { + defer w.wg.Done() + + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + req := w.buildHeartbeatRequest() + _, err := w.client.Heartbeat(w.ctx, req) + if err != nil { + log.Printf("Worker heartbeat error: %v", err) + } + case <-w.ctx.Done(): + return + } + } +} + +func (w *Worker) fetchLoop() { + defer w.wg.Done() + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + req := w.buildFetchTasksRequest() + res, err := w.client.FetchTask(w.ctx, req) + if err != nil { + log.Printf("Worker fetch task error: %v", err) + continue + } + + if !res.HasTask { + continue + } + + handler, exists := w.handlers[res.Task.Type] + if !exists { + log.Printf("No handler registered for task type: %s", res.Task.Type) + continue + } + + w.incrementLoad() + + handleTask := w.getTaskHandler(handler) + + go handleTask(res.Task) + case <-w.ctx.Done(): + return + } } +} + +func (w *Worker) getTaskHandler(handler TaskHandler) func(task *proto.Task) { + return func(task *proto.Task) { + defer w.decrementLoad() + + result, err := handler.Handle(w.ctx, task.Payload) + submitReq := &proto.SubmitResultRequest{ + TaskId: task.Id, + } + if err != nil { + submitReq.Error = err.Error() + submitReq.Result = nil + } else { + submitReq.Error = "" + submitReq.Result = result + } - w.ID = uuid.String() - w.TaskTypes = pw.TaskTypes - w.Address = pw.Metadata["address"] - w.Capacity = int(pw.Capacity) - w.CurrentLoad = 0 - w.Metadata = pw.Metadata - w.RegisteredAt = time.Now() - w.LastHeartbeat = time.Now() + _, err = w.client.SubmitResult(w.ctx, submitReq) + if err != nil { + log.Printf("Error submitting task result: %v", err) + } + } +} +func (w *Worker) getCurrentLoad() int32 { + w.loadMux.RLock() + defer w.loadMux.RUnlock() + + return int32(w.currentLoad) +} + +func (w *Worker) incrementLoad() { + w.loadMux.Lock() + defer w.loadMux.Unlock() + + w.currentLoad++ +} + +func (w *Worker) decrementLoad() { + w.loadMux.Lock() + defer w.loadMux.Unlock() + + if w.currentLoad > 0 { + w.currentLoad-- + } +} + +func (w *Worker) register() error { + req := w.buildRegisterRequest() + + res, err := w.client.RegisterWorker(w.ctx, req) + if err != nil { + return err + } + + w.id = res.WorkerId return nil } + +func (w *Worker) buildRegisterRequest() *proto.RegisterWorkerRequest { + taskTypes := make([]string, 0, len(w.handlers)) + for taskType := range w.handlers { + taskTypes = append(taskTypes, taskType) + } + + return &proto.RegisterWorkerRequest{ + Worker: &proto.Worker{ + TaskTypes: taskTypes, + Capacity: int32(w.capacity), + Metadata: map[string]string{ + "address": w.serverAddr, + }, + }, + } +} + +func (w *Worker) buildFetchTasksRequest() *proto.FetchTaskRequest { + taskTypes := make([]string, 0, len(w.handlers)) + for taskType := range w.handlers { + taskTypes = append(taskTypes, taskType) + } + + return &proto.FetchTaskRequest{ + WorkerId: w.id, + TaskTypes: taskTypes, + } +} + +func (w *Worker) buildHeartbeatRequest() *proto.HeartbeatRequest { + return &proto.HeartbeatRequest{ + WorkerId: w.id, + CurrentLoad: w.getCurrentLoad(), + } +} diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go new file mode 100644 index 0000000..110e34e --- /dev/null +++ b/internal/worker/worker_test.go @@ -0,0 +1,506 @@ +package worker + +import ( + "context" + "errors" + "testing" + "time" +) + +// MockHandler is a mock implementation of TaskHandler for testing +type MockHandler struct { + handleFunc func(ctx context.Context, payload []byte) ([]byte, error) +} + +func (m *MockHandler) Handle(ctx context.Context, payload []byte) ([]byte, error) { + if m.handleFunc != nil { + return m.handleFunc(ctx, payload) + } + return payload, nil +} + +func TestNewWorker(t *testing.T) { + tests := []struct { + name string + serverAddr string + capacity int + }{ + { + name: "basic worker", + serverAddr: "localhost:8080", + capacity: 5, + }, + { + name: "high capacity worker", + serverAddr: "localhost:8080", + capacity: 10, + }, + { + name: "single capacity worker", + serverAddr: "localhost:8080", + capacity: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + worker := NewWorker(tt.serverAddr, tt.capacity) + + if worker == nil { + t.Fatal("NewWorker returned nil") + } + + if worker.serverAddr != tt.serverAddr { + t.Errorf("expected serverAddr %s, got %s", tt.serverAddr, worker.serverAddr) + } + + if worker.capacity != tt.capacity { + t.Errorf("expected capacity %d, got %d", tt.capacity, worker.capacity) + } + + if worker.handlers == nil { + t.Error("handlers map not initialized") + } + + if worker.ctx == nil { + t.Error("context not initialized") + } + + if worker.cancel == nil { + t.Error("cancel function not initialized") + } + + if worker.currentLoad != 0 { + t.Errorf("expected currentLoad 0, got %d", worker.currentLoad) + } + + // Verify handlers map is initialized empty + if len(worker.handlers) != 0 { + t.Errorf("expected empty handlers map, got %d handlers", len(worker.handlers)) + } + }) + } +} + +func TestRegisterHandler(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + handler1 := &MockHandler{} + handler2 := &MockHandler{} + + worker.RegisterHandler("task1", handler1) + worker.RegisterHandler("task2", handler2) + + if worker.handlers["task1"] != handler1 { + t.Error("handler1 not registered correctly") + } + + if worker.handlers["task2"] != handler2 { + t.Error("handler2 not registered correctly") + } +} + +func TestGetCurrentLoad(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + // Initial load should be 0 + if worker.getCurrentLoad() != 0 { + t.Errorf("expected initial load 0, got %d", worker.getCurrentLoad()) + } + + // Manually set load to test getter + worker.currentLoad = 3 + if worker.getCurrentLoad() != 3 { + t.Errorf("expected load 3, got %d", worker.getCurrentLoad()) + } +} + +func TestIncrementLoad(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + if worker.getCurrentLoad() != 0 { + t.Errorf("expected initial load 0, got %d", worker.getCurrentLoad()) + } + + worker.incrementLoad() + if worker.getCurrentLoad() != 1 { + t.Errorf("expected load 1, got %d", worker.getCurrentLoad()) + } + + worker.incrementLoad() + if worker.getCurrentLoad() != 2 { + t.Errorf("expected load 2, got %d", worker.getCurrentLoad()) + } + + worker.incrementLoad() + if worker.getCurrentLoad() != 3 { + t.Errorf("expected load 3, got %d", worker.getCurrentLoad()) + } +} + +func TestDecrementLoad(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + // Set initial load + worker.currentLoad = 3 + + worker.decrementLoad() + if worker.getCurrentLoad() != 2 { + t.Errorf("expected load 2, got %d", worker.getCurrentLoad()) + } + + worker.decrementLoad() + if worker.getCurrentLoad() != 1 { + t.Errorf("expected load 1, got %d", worker.getCurrentLoad()) + } + + worker.decrementLoad() + if worker.getCurrentLoad() != 0 { + t.Errorf("expected load 0, got %d", worker.getCurrentLoad()) + } + + // Decrementing below 0 should keep it at 0 + worker.decrementLoad() + if worker.getCurrentLoad() != 0 { + t.Errorf("expected load to stay at 0, got %d", worker.getCurrentLoad()) + } +} + +func TestIncrementDecrementLoadConcurrency(t *testing.T) { + worker := NewWorker("localhost:8080", 100) + + done := make(chan bool) + + // Simulate concurrent increments + for i := 0; i < 50; i++ { + go func() { + worker.incrementLoad() + done <- true + }() + } + + // Wait for all increments + for i := 0; i < 50; i++ { + <-done + } + + if worker.getCurrentLoad() != 50 { + t.Errorf("expected load 50 after increments, got %d", worker.getCurrentLoad()) + } + + // Simulate concurrent decrements + for i := 0; i < 30; i++ { + go func() { + worker.decrementLoad() + done <- true + }() + } + + // Wait for all decrements + for i := 0; i < 30; i++ { + <-done + } + + if worker.getCurrentLoad() != 20 { + t.Errorf("expected load 20 after decrements, got %d", worker.getCurrentLoad()) + } +} + +func TestBuildRegisterRequest(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + worker.RegisterHandler("task1", &MockHandler{}) + worker.RegisterHandler("task2", &MockHandler{}) + + req := worker.buildRegisterRequest() + + if req == nil { + t.Fatal("buildRegisterRequest returned nil") + } + + if req.Worker == nil { + t.Fatal("Worker in request is nil") + } + + if len(req.Worker.TaskTypes) != 2 { + t.Errorf("expected 2 task types, got %d", len(req.Worker.TaskTypes)) + } + + if req.Worker.Capacity != 5 { + t.Errorf("expected capacity 5, got %d", req.Worker.Capacity) + } + + if req.Worker.Metadata == nil { + t.Fatal("Metadata is nil") + } + + if req.Worker.Metadata["address"] != "localhost:8080" { + t.Errorf("expected address localhost:8080, got %s", req.Worker.Metadata["address"]) + } +} + +func TestBuildFetchTasksRequest(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + worker.RegisterHandler("task1", &MockHandler{}) + worker.RegisterHandler("task2", &MockHandler{}) + worker.id = "worker-123" + + req := worker.buildFetchTasksRequest() + + if req == nil { + t.Fatal("buildFetchTasksRequest returned nil") + } + + if req.WorkerId != "worker-123" { + t.Errorf("expected WorkerId worker-123, got %s", req.WorkerId) + } + + if len(req.TaskTypes) != 2 { + t.Errorf("expected 2 task types, got %d", len(req.TaskTypes)) + } +} + +func TestBuildHeartbeatRequest(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + worker.id = "worker-123" + worker.currentLoad = 3 + + req := worker.buildHeartbeatRequest() + + if req == nil { + t.Fatal("buildHeartbeatRequest returned nil") + } + + if req.WorkerId != "worker-123" { + t.Errorf("expected WorkerId worker-123, got %s", req.WorkerId) + } + + if req.CurrentLoad != 3 { + t.Errorf("expected CurrentLoad 3, got %d", req.CurrentLoad) + } +} + +func TestGetTaskHandler(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + expectedPayload := []byte("test payload") + expectedResult := []byte("test result") + + mockHandler := &MockHandler{ + handleFunc: func(ctx context.Context, payload []byte) ([]byte, error) { + if string(payload) != string(expectedPayload) { + t.Errorf("expected payload %s, got %s", expectedPayload, payload) + } + return expectedResult, nil + }, + } + + worker.RegisterHandler("task1", mockHandler) + + taskHandler := worker.getTaskHandler(mockHandler) + + if taskHandler == nil { + t.Fatal("getTaskHandler returned nil") + } + + // Note: We can't fully test the task handler without mocking the gRPC client + // but we can verify it returns a function +} + +func TestGetTaskHandlerWithError(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + expectedError := errors.New("handler error") + + mockHandler := &MockHandler{ + handleFunc: func(ctx context.Context, payload []byte) ([]byte, error) { + return nil, expectedError + }, + } + + taskHandler := worker.getTaskHandler(mockHandler) + + if taskHandler == nil { + t.Fatal("getTaskHandler returned nil") + } +} + +func TestStop(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + // Start a goroutine that should be cancelled by Stop + done := make(chan bool) + go func() { + <-worker.ctx.Done() + done <- true + }() + + worker.Stop() + + // Wait for context cancellation with timeout + select { + case <-done: + // Success - context was cancelled + case <-time.After(1 * time.Second): + t.Fatal("Context was not cancelled within timeout") + } +} + +func TestHandlersInitializedEmpty(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + // Verify handlers map is initialized but empty + if worker.handlers == nil { + t.Fatal("handlers map should be initialized") + } + + if len(worker.handlers) != 0 { + t.Errorf("handlers map should be empty, got %d handlers", len(worker.handlers)) + } +} + +func TestRegisterHandlerOverwrite(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + handler1 := &MockHandler{} + handler2 := &MockHandler{} + + worker.RegisterHandler("task1", handler1) + if worker.handlers["task1"] != handler1 { + t.Error("handler1 not registered correctly") + } + + // Overwrite with handler2 + worker.RegisterHandler("task1", handler2) + if worker.handlers["task1"] != handler2 { + t.Error("handler2 not registered correctly after overwrite") + } +} + +func TestWorkerContextCancellation(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + if worker.ctx == nil { + t.Fatal("worker context is nil") + } + + if worker.cancel == nil { + t.Fatal("worker cancel function is nil") + } + + // Initially context should not be cancelled + select { + case <-worker.ctx.Done(): + t.Fatal("Context should not be cancelled initially") + default: + // Good - context is not cancelled + } + + // Cancel the context + worker.cancel() + + // Now context should be cancelled + select { + case <-worker.ctx.Done(): + // Good - context is cancelled + case <-time.After(100 * time.Millisecond): + t.Fatal("Context was not cancelled") + } +} + +func TestLoadMutexProtection(t *testing.T) { + worker := NewWorker("localhost:8080", 100) + + // Test that concurrent reads and writes don't cause race conditions + done := make(chan bool) + iterations := 100 + + // Concurrent readers + for i := 0; i < iterations; i++ { + go func() { + _ = worker.getCurrentLoad() + done <- true + }() + } + + // Concurrent writers (increments) + for i := 0; i < iterations; i++ { + go func() { + worker.incrementLoad() + done <- true + }() + } + + // Concurrent writers (decrements) + for i := 0; i < iterations; i++ { + go func() { + worker.decrementLoad() + done <- true + }() + } + + // Wait for all operations + for i := 0; i < iterations*3; i++ { + <-done + } + + // Just verify we can read the final load without panic + finalLoad := worker.getCurrentLoad() + if finalLoad < 0 { + t.Errorf("final load should not be negative, got %d", finalLoad) + } +} + +func TestMockHandlerDefaultBehavior(t *testing.T) { + handler := &MockHandler{} + payload := []byte("test") + + result, err := handler.Handle(context.Background(), payload) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if string(result) != string(payload) { + t.Errorf("expected result %s, got %s", payload, result) + } +} + +func TestMockHandlerCustomBehavior(t *testing.T) { + expectedError := errors.New("custom error") + handler := &MockHandler{ + handleFunc: func(ctx context.Context, payload []byte) ([]byte, error) { + return nil, expectedError + }, + } + + result, err := handler.Handle(context.Background(), []byte("test")) + + if err != expectedError { + t.Errorf("expected error %v, got %v", expectedError, err) + } + + if result != nil { + t.Errorf("expected nil result, got %v", result) + } +} + +func TestWorkerHandlerRegistration(t *testing.T) { + worker := NewWorker("localhost:8080", 5) + + // Register multiple handlers + taskTypes := []string{"type1", "type2", "type3"} + for _, taskType := range taskTypes { + worker.RegisterHandler(taskType, &MockHandler{}) + } + + // Verify all handlers are registered + if len(worker.handlers) != len(taskTypes) { + t.Errorf("expected %d handlers, got %d", len(taskTypes), len(worker.handlers)) + } + + for _, taskType := range taskTypes { + if _, exists := worker.handlers[taskType]; !exists { + t.Errorf("handler for task type %s should be registered", taskType) + } + } +}