From cb09aaafd12e9f46fce9dd05454985f98588a068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:18:44 -0300 Subject: [PATCH 01/10] feat: add spew to debug structs --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index f1c92c6..ef4b30b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/mateusmlo/taskqueue go 1.24.1 require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/sys v0.34.0 // indirect diff --git a/go.sum b/go.sum index eb70b16..b7a8fd2 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= From ed3dc0e360ce8141b764c8eb51d2fa0196d170d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:19:18 -0300 Subject: [PATCH 02/10] chore(protobuf): refactor task result struct --- proto/taskqueue.pb.go | 59 +++++++++++++++++++------------------- proto/taskqueue.proto | 2 +- proto/taskqueue_grpc.pb.go | 2 +- 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/proto/taskqueue.pb.go b/proto/taskqueue.pb.go index 0b75982..f1a17eb 100644 --- a/proto/taskqueue.pb.go +++ b/proto/taskqueue.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.10 -// protoc v6.33.0 +// protoc v6.33.2 // source: proto/taskqueue.proto package proto @@ -499,7 +499,7 @@ func (x *GetTaskResultRequest) GetTaskId() string { type GetTaskResultResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - Task *Task `protobuf:"bytes,1,opt,name=task,proto3" json:"task,omitempty"` + Result []byte `protobuf:"bytes,1,opt,name=result,proto3" json:"result,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -534,9 +534,9 @@ func (*GetTaskResultResponse) Descriptor() ([]byte, []int) { return file_proto_taskqueue_proto_rawDescGZIP(), []int{6} } -func (x *GetTaskResultResponse) GetTask() *Task { +func (x *GetTaskResultResponse) GetResult() []byte { if x != nil { - return x.Task + return x.Result } return nil } @@ -1089,9 +1089,9 @@ const file_proto_taskqueue_proto_rawDesc = "" + "\x06status\x18\x01 \x01(\x0e2\x15.taskqueue.TaskStatusR\x06status\x12\x14\n" + "\x05error\x18\x02 \x01(\tR\x05error\"/\n" + "\x14GetTaskResultRequest\x12\x17\n" + - "\atask_id\x18\x01 \x01(\tR\x06taskId\"<\n" + - "\x15GetTaskResultResponse\x12#\n" + - "\x04task\x18\x01 \x01(\v2\x0f.taskqueue.TaskR\x04task\"\xa4\x02\n" + + "\atask_id\x18\x01 \x01(\tR\x06taskId\"/\n" + + "\x15GetTaskResultResponse\x12\x16\n" + + "\x06result\x18\x01 \x01(\fR\x06result\"\xa4\x02\n" + "\x06Worker\x12\x1b\n" + "\tworker_id\x18\x01 \x01(\tR\bworkerId\x12\x1d\n" + "\n" + @@ -1195,29 +1195,28 @@ var file_proto_taskqueue_proto_depIdxs = []int32{ 19, // 3: taskqueue.Task.completed_at:type_name -> google.protobuf.Timestamp 0, // 4: taskqueue.Task.status:type_name -> taskqueue.TaskStatus 0, // 5: taskqueue.GetTaskStatusResponse.status:type_name -> taskqueue.TaskStatus - 2, // 6: taskqueue.GetTaskResultResponse.task:type_name -> taskqueue.Task - 18, // 7: taskqueue.Worker.metadata:type_name -> taskqueue.Worker.MetadataEntry - 9, // 8: taskqueue.RegisterWorkerRequest.worker:type_name -> taskqueue.Worker - 2, // 9: taskqueue.FetchTaskResponse.task:type_name -> taskqueue.Task - 3, // 10: taskqueue.TaskQueue.SubmitTask:input_type -> taskqueue.SubmitTaskRequest - 7, // 11: taskqueue.TaskQueue.GetTaskResult:input_type -> taskqueue.GetTaskResultRequest - 5, // 12: taskqueue.TaskQueue.GetTaskStatus:input_type -> taskqueue.GetTaskStatusRequest - 10, // 13: taskqueue.WorkerService.RegisterWorker:input_type -> taskqueue.RegisterWorkerRequest - 14, // 14: taskqueue.WorkerService.FetchTask:input_type -> taskqueue.FetchTaskRequest - 16, // 15: taskqueue.WorkerService.SubmitResult:input_type -> taskqueue.SubmitResultRequest - 12, // 16: taskqueue.WorkerService.Heartbeat:input_type -> taskqueue.HeartbeatRequest - 4, // 17: taskqueue.TaskQueue.SubmitTask:output_type -> taskqueue.SubmitTaskResponse - 8, // 18: taskqueue.TaskQueue.GetTaskResult:output_type -> taskqueue.GetTaskResultResponse - 6, // 19: taskqueue.TaskQueue.GetTaskStatus:output_type -> taskqueue.GetTaskStatusResponse - 11, // 20: taskqueue.WorkerService.RegisterWorker:output_type -> taskqueue.RegisterWorkerResponse - 15, // 21: taskqueue.WorkerService.FetchTask:output_type -> taskqueue.FetchTaskResponse - 17, // 22: taskqueue.WorkerService.SubmitResult:output_type -> taskqueue.SubmitResultResponse - 13, // 23: taskqueue.WorkerService.Heartbeat:output_type -> taskqueue.HeartbeatResponse - 17, // [17:24] is the sub-list for method output_type - 10, // [10:17] is the sub-list for method input_type - 10, // [10:10] is the sub-list for extension type_name - 10, // [10:10] is the sub-list for extension extendee - 0, // [0:10] is the sub-list for field type_name + 18, // 6: taskqueue.Worker.metadata:type_name -> taskqueue.Worker.MetadataEntry + 9, // 7: taskqueue.RegisterWorkerRequest.worker:type_name -> taskqueue.Worker + 2, // 8: taskqueue.FetchTaskResponse.task:type_name -> taskqueue.Task + 3, // 9: taskqueue.TaskQueue.SubmitTask:input_type -> taskqueue.SubmitTaskRequest + 7, // 10: taskqueue.TaskQueue.GetTaskResult:input_type -> taskqueue.GetTaskResultRequest + 5, // 11: taskqueue.TaskQueue.GetTaskStatus:input_type -> taskqueue.GetTaskStatusRequest + 10, // 12: taskqueue.WorkerService.RegisterWorker:input_type -> taskqueue.RegisterWorkerRequest + 14, // 13: taskqueue.WorkerService.FetchTask:input_type -> taskqueue.FetchTaskRequest + 16, // 14: taskqueue.WorkerService.SubmitResult:input_type -> taskqueue.SubmitResultRequest + 12, // 15: taskqueue.WorkerService.Heartbeat:input_type -> taskqueue.HeartbeatRequest + 4, // 16: taskqueue.TaskQueue.SubmitTask:output_type -> taskqueue.SubmitTaskResponse + 8, // 17: taskqueue.TaskQueue.GetTaskResult:output_type -> taskqueue.GetTaskResultResponse + 6, // 18: taskqueue.TaskQueue.GetTaskStatus:output_type -> taskqueue.GetTaskStatusResponse + 11, // 19: taskqueue.WorkerService.RegisterWorker:output_type -> taskqueue.RegisterWorkerResponse + 15, // 20: taskqueue.WorkerService.FetchTask:output_type -> taskqueue.FetchTaskResponse + 17, // 21: taskqueue.WorkerService.SubmitResult:output_type -> taskqueue.SubmitResultResponse + 13, // 22: taskqueue.WorkerService.Heartbeat:output_type -> taskqueue.HeartbeatResponse + 16, // [16:23] is the sub-list for method output_type + 9, // [9:16] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_proto_taskqueue_proto_init() } diff --git a/proto/taskqueue.proto b/proto/taskqueue.proto index bc6a4ba..5a7b2a2 100644 --- a/proto/taskqueue.proto +++ b/proto/taskqueue.proto @@ -65,7 +65,7 @@ message GetTaskResultRequest { } message GetTaskResultResponse { - Task task = 1; + bytes result = 1; } // ============================================ diff --git a/proto/taskqueue_grpc.pb.go b/proto/taskqueue_grpc.pb.go index aa0bb53..c4e3b65 100644 --- a/proto/taskqueue_grpc.pb.go +++ b/proto/taskqueue_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v6.33.0 +// - protoc v6.33.2 // source: proto/taskqueue.proto package proto From 2e7131a56a0e9c3eb3e2b203b9d07d232aa11e85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:20:08 -0300 Subject: [PATCH 03/10] chore: minor type fixes --- internal/server/server.go | 2 +- internal/server/worker_info.go | 2 +- internal/worker/worker.go | 19 +++++++++++++++++++ internal/worker/worker_test.go | 14 +++++++------- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index d92d62c..adcf578 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -151,7 +151,7 @@ func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequ return nil, status.Errorf(codes.FailedPrecondition, "task %s not completed yet", req.TaskId) } - return &proto.GetTaskResultResponse{Task: task.toProtoTask()}, nil + return &proto.GetTaskResultResponse{Result: task.Result}, nil } // RegisterWorker handles worker registration requests diff --git a/internal/server/worker_info.go b/internal/server/worker_info.go index dc4ec59..b13b3ee 100644 --- a/internal/server/worker_info.go +++ b/internal/server/worker_info.go @@ -9,7 +9,7 @@ import ( "google.golang.org/grpc/status" ) -// WorkerInfo tracks a registered worker (server-side only) +// WorkerInfo tracks a registered worker's information type WorkerInfo struct { ID string Address string diff --git a/internal/worker/worker.go b/internal/worker/worker.go index b1f0fbe..541dd1a 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -2,6 +2,7 @@ package worker import ( "context" + "fmt" "log" "sync" "time" @@ -32,6 +33,7 @@ type TaskHandler interface { Handle(ctx context.Context, payload []byte) ([]byte, error) } +// NewWorker creates a new Worker instance. func NewWorker(serverAddr string, capacity int) *Worker { ctx, cancel := context.WithCancel(context.Background()) @@ -44,10 +46,12 @@ func NewWorker(serverAddr string, capacity int) *Worker { } } +// RegisterHandler registers a task handler for a specific task type. func (w *Worker) RegisterHandler(taskType string, handler TaskHandler) { w.handlers[taskType] = handler } +// Start connects the worker to the server and begins processing tasks. func (w *Worker) Start() error { tcr, err := credentials.NewClientTLSFromFile("./cert/server.crt", "localhost") if err != nil { @@ -74,6 +78,7 @@ func (w *Worker) Start() error { return nil } +// Stop stops the worker and cleans up resources. func (w *Worker) Stop() { w.cancel() w.wg.Wait() @@ -87,6 +92,7 @@ func (w *Worker) Stop() { } } +// heartbeatLoop sends periodic heartbeat messages to the server. func (w *Worker) heartbeatLoop() { defer w.wg.Done() @@ -107,6 +113,7 @@ func (w *Worker) heartbeatLoop() { } } +// fetchLoop continuously fetches tasks from the server and processes them. func (w *Worker) fetchLoop() { defer w.wg.Done() @@ -144,6 +151,7 @@ func (w *Worker) fetchLoop() { } } +// getTaskHandler returns a function that processes a task using the provided handler. func (w *Worker) getTaskHandler(handler TaskHandler) func(task *proto.Task) { return func(task *proto.Task) { defer w.decrementLoad() @@ -167,6 +175,7 @@ func (w *Worker) getTaskHandler(handler TaskHandler) func(task *proto.Task) { } } +// getCurrentLoad safely retrieves the current load of the worker. func (w *Worker) getCurrentLoad() int32 { w.loadMux.RLock() defer w.loadMux.RUnlock() @@ -174,6 +183,7 @@ func (w *Worker) getCurrentLoad() int32 { return int32(w.currentLoad) } +// incrementLoad safely increments the current load of the worker. func (w *Worker) incrementLoad() { w.loadMux.Lock() defer w.loadMux.Unlock() @@ -181,6 +191,7 @@ func (w *Worker) incrementLoad() { w.currentLoad++ } +// decrementLoad safely decrements the current load of the worker. func (w *Worker) decrementLoad() { w.loadMux.Lock() defer w.loadMux.Unlock() @@ -190,6 +201,7 @@ func (w *Worker) decrementLoad() { } } +// register registers the worker with the server and obtains a worker ID. func (w *Worker) register() error { req := w.buildRegisterRequest() @@ -202,6 +214,7 @@ func (w *Worker) register() error { return nil } +// buildRegisterRequest constructs the RegisterWorkerRequest message. func (w *Worker) buildRegisterRequest() *proto.RegisterWorkerRequest { taskTypes := make([]string, 0, len(w.handlers)) for taskType := range w.handlers { @@ -219,6 +232,7 @@ func (w *Worker) buildRegisterRequest() *proto.RegisterWorkerRequest { } } +// buildFetchTasksRequest constructs the FetchTaskRequest message. func (w *Worker) buildFetchTasksRequest() *proto.FetchTaskRequest { taskTypes := make([]string, 0, len(w.handlers)) for taskType := range w.handlers { @@ -231,9 +245,14 @@ func (w *Worker) buildFetchTasksRequest() *proto.FetchTaskRequest { } } +// buildHeartbeatRequest constructs the HeartbeatRequest message. func (w *Worker) buildHeartbeatRequest() *proto.HeartbeatRequest { return &proto.HeartbeatRequest{ WorkerId: w.id, CurrentLoad: w.getCurrentLoad(), } } + +func (w *Worker) GetWorkerID() string { + return fmt.Sprintf("Worker:%s", w.id) +} diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index 110e34e..ccf245b 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -172,7 +172,7 @@ func TestIncrementDecrementLoadConcurrency(t *testing.T) { done := make(chan bool) // Simulate concurrent increments - for i := 0; i < 50; i++ { + for range 50 { go func() { worker.incrementLoad() done <- true @@ -180,7 +180,7 @@ func TestIncrementDecrementLoadConcurrency(t *testing.T) { } // Wait for all increments - for i := 0; i < 50; i++ { + for range 50 { <-done } @@ -189,7 +189,7 @@ func TestIncrementDecrementLoadConcurrency(t *testing.T) { } // Simulate concurrent decrements - for i := 0; i < 30; i++ { + for range 30 { go func() { worker.decrementLoad() done <- true @@ -197,7 +197,7 @@ func TestIncrementDecrementLoadConcurrency(t *testing.T) { } // Wait for all decrements - for i := 0; i < 30; i++ { + for range 30 { <-done } @@ -415,7 +415,7 @@ func TestLoadMutexProtection(t *testing.T) { iterations := 100 // Concurrent readers - for i := 0; i < iterations; i++ { + for range iterations { go func() { _ = worker.getCurrentLoad() done <- true @@ -423,7 +423,7 @@ func TestLoadMutexProtection(t *testing.T) { } // Concurrent writers (increments) - for i := 0; i < iterations; i++ { + for range iterations { go func() { worker.incrementLoad() done <- true @@ -431,7 +431,7 @@ func TestLoadMutexProtection(t *testing.T) { } // Concurrent writers (decrements) - for i := 0; i < iterations; i++ { + for range iterations { go func() { worker.decrementLoad() done <- true From 7c84b356ad59fc34d8d1e8e06df1607db0bc5eb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:20:41 -0300 Subject: [PATCH 04/10] feat(internal): add graceful shutdown helper --- internal/helper/helper.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 internal/helper/helper.go diff --git a/internal/helper/helper.go b/internal/helper/helper.go new file mode 100644 index 0000000..99c175f --- /dev/null +++ b/internal/helper/helper.go @@ -0,0 +1,23 @@ +package helper + +import ( + "log" + "os" + "os/signal" + "syscall" +) + +type EventCaller int + +func SetupGracefulShutdown(shutdownFn func(), caller string) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + + log.Printf("🔌 Shutting down %s...", caller) + shutdownFn() + os.Exit(0) + }() +} From c8d2d948866ac7a8d188fcc06b14758cb60fc56e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:21:13 -0300 Subject: [PATCH 05/10] feat(cmd/server): implement grpc server --- cmd/server/main_test.go | 576 ++++++++++++++++++++++++++++++++++++++++ cmd/server/server.go | 38 +++ 2 files changed, 614 insertions(+) create mode 100644 cmd/server/main_test.go create mode 100644 cmd/server/server.go diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go new file mode 100644 index 0000000..32759b5 --- /dev/null +++ b/cmd/server/main_test.go @@ -0,0 +1,576 @@ +package main + +import ( + "context" + "net" + "os" + "syscall" + "testing" + "time" + + "github.com/mateusmlo/taskqueue/internal/helper" + "github.com/mateusmlo/taskqueue/internal/server" + pb "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +// bufDialer creates a gRPC connection using an in-memory buffer connection +func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) { + return func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + } +} + +// TestServerInitialization tests that the server can be initialized without errors +func TestServerInitialization(t *testing.T) { + s := server.NewServer() + + if s == nil { + t.Fatal("NewServer() returned nil") + } + + // Verify server is properly initialized + if s == nil { + t.Error("Server instance is nil") + } +} + +// TestGRPCServerRegistration tests that both services are properly registered +func TestGRPCServerRegistration(t *testing.T) { + // Create in-memory listener for testing + listener := bufconn.Listen(bufSize) + defer listener.Close() + + // Create server and gRPC server + s := server.NewServer() + grpcServer := grpc.NewServer() + + // Register both services + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + // Start server in background + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + defer grpcServer.Stop() + + // Create client connection + ctx := context.Background() + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer conn.Close() + + // Test TaskQueue service + taskQueueClient := pb.NewTaskQueueClient(conn) + submitResp, err := taskQueueClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "test-task", + Payload: []byte("test payload"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + if err != nil { + t.Errorf("SubmitTask failed: %v", err) + } + if submitResp == nil { + t.Error("SubmitTask returned nil response") + } + if submitResp != nil && submitResp.TaskId == "" { + t.Error("SubmitTask returned empty task ID") + } + + // Test WorkerService + workerClient := pb.NewWorkerServiceClient(conn) + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"test-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Errorf("RegisterWorker failed: %v", err) + } + if registerResp == nil { + t.Error("RegisterWorker returned nil response") + } + if registerResp != nil && !registerResp.Success { + t.Error("RegisterWorker returned success=false") + } + if registerResp != nil && registerResp.WorkerId == "" { + t.Error("RegisterWorker returned empty worker ID") + } +} + +// TestTaskQueueServiceEndpoints tests all TaskQueue service endpoints +func TestTaskQueueServiceEndpoints(t *testing.T) { + listener := bufconn.Listen(bufSize) + defer listener.Close() + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterTaskQueueServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + defer grpcServer.Stop() + + ctx := context.Background() + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer conn.Close() + + client := pb.NewTaskQueueClient(conn) + + // Test SubmitTask + submitResp, err := client.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "integration-test", + Payload: []byte("test data"), + Priority: int32(pb.Priority_MEDIUM), + MaxRetries: 5, + }) + if err != nil { + t.Fatalf("SubmitTask failed: %v", err) + } + if submitResp.TaskId == "" { + t.Fatal("SubmitTask returned empty task ID") + } + + taskID := submitResp.TaskId + + // Test GetTaskStatus + statusResp, err := client.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Errorf("GetTaskStatus failed: %v", err) + } + if statusResp == nil { + t.Fatal("GetTaskStatus returned nil response") + } + if statusResp.Status != pb.TaskStatus_PENDING { + t.Errorf("GetTaskStatus status = %v, want PENDING", statusResp.Status) + } + + // Test GetTaskResult on pending task (should fail with FailedPrecondition) + _, err = client.GetTaskResult(ctx, &pb.GetTaskResultRequest{ + TaskId: taskID, + }) + if err == nil { + t.Error("GetTaskResult expected error for pending task, got nil") + } + // Note: Full GetTaskResult success test is in TestWorkerTaskLifecycle +} + +// TestWorkerServiceEndpoints tests all WorkerService endpoints +func TestWorkerServiceEndpoints(t *testing.T) { + listener := bufconn.Listen(bufSize) + defer listener.Close() + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterWorkerServiceServer(grpcServer, s) + pb.RegisterTaskQueueServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + defer grpcServer.Stop() + + ctx := context.Background() + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer conn.Close() + + workerClient := pb.NewWorkerServiceClient(conn) + taskClient := pb.NewTaskQueueClient(conn) + + // Test RegisterWorker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"test-task", "batch-job"}, + Capacity: 15, + Metadata: map[string]string{ + "address": "localhost:9000", + "region": "us-east-1", + }, + }, + }) + if err != nil { + t.Fatalf("RegisterWorker failed: %v", err) + } + if registerResp.WorkerId == "" { + t.Fatal("RegisterWorker returned empty worker ID") + } + if !registerResp.Success { + t.Error("RegisterWorker success = false, want true") + } + + workerID := registerResp.WorkerId + + // Test Heartbeat + heartbeatResp, err := workerClient.Heartbeat(ctx, &pb.HeartbeatRequest{ + WorkerId: workerID, + CurrentLoad: 5, + }) + if err != nil { + t.Errorf("Heartbeat failed: %v", err) + } + if heartbeatResp == nil { + t.Fatal("Heartbeat returned nil response") + } + if !heartbeatResp.Success { + t.Error("Heartbeat success = false, want true") + } + if heartbeatResp.CurrentLoad != 5 { + t.Errorf("Heartbeat current_load = %v, want 5", heartbeatResp.CurrentLoad) + } + + // Submit a task for FetchTask test + submitResp, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "test-task", + Payload: []byte("worker test"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("SubmitTask failed: %v", err) + } + + taskID := submitResp.TaskId + + // Test FetchTask + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: workerID, + TaskTypes: []string{"test-task"}, + }) + if err != nil { + t.Errorf("FetchTask failed: %v", err) + } + if fetchResp == nil { + t.Fatal("FetchTask returned nil response") + } + if !fetchResp.HasTask { + t.Error("FetchTask has_task = false, want true") + } + if fetchResp.Task == nil { + t.Fatal("FetchTask returned nil task") + } + if fetchResp.Task.Id != taskID { + t.Errorf("FetchTask task ID = %v, want %v", fetchResp.Task.Id, taskID) + } + + // Test SubmitResult + submitResultResp, err := workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: []byte("task completed successfully"), + Error: "", + }) + if err != nil { + t.Errorf("SubmitResult failed: %v", err) + } + if submitResultResp == nil { + t.Fatal("SubmitResult returned nil response") + } + if !submitResultResp.Success { + t.Error("SubmitResult success = false, want true") + } +} + +// TestGracefulShutdown tests the graceful shutdown mechanism +func TestGracefulShutdown(t *testing.T) { + grpcServer := grpc.NewServer() + defer grpcServer.Stop() + + // Setup graceful shutdown + helper.SetupGracefulShutdown(grpcServer.GracefulStop, "TEST") + + // Create a listener to start the server + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + + // Start server in background + serverStarted := make(chan struct{}) + go func() { + close(serverStarted) + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server stopped: %v", err) + } + }() + + // Wait for server to start + <-serverStarted + time.Sleep(100 * time.Millisecond) + + // Send SIGTERM to trigger graceful shutdown + process, err := os.FindProcess(os.Getpid()) + if err != nil { + t.Fatalf("Failed to find process: %v", err) + } + + // We can't actually test the full shutdown flow without exiting the test, + // but we can verify the signal handler is set up + err = process.Signal(syscall.Signal(0)) // Signal 0 checks if process exists + if err != nil { + t.Errorf("Process check failed: %v", err) + } + + // Clean shutdown for test + grpcServer.GracefulStop() +} + +// TestServerWithMultipleClients tests concurrent client connections +func TestServerWithMultipleClients(t *testing.T) { + listener := bufconn.Listen(bufSize) + defer listener.Close() + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + defer grpcServer.Stop() + + const numClients = 10 + errors := make(chan error, numClients) + taskIDs := make(chan string, numClients) + + // Spawn multiple clients submitting tasks concurrently + for i := range numClients { + go func(clientID int) { + ctx := context.Background() + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + errors <- err + return + } + defer conn.Close() + + client := pb.NewTaskQueueClient(conn) + resp, err := client.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "concurrent-test", + Payload: []byte("test data"), + Priority: int32(pb.Priority_MEDIUM), + MaxRetries: 3, + }) + + if err != nil { + errors <- err + } else { + taskIDs <- resp.TaskId + errors <- nil + } + }(i) + } + + // Collect results + uniqueTaskIDs := make(map[string]bool) + for range numClients { + err := <-errors + if err != nil { + t.Errorf("Client request failed: %v", err) + continue + } + + taskID := <-taskIDs + if uniqueTaskIDs[taskID] { + t.Errorf("Duplicate task ID: %s", taskID) + } + uniqueTaskIDs[taskID] = true + } + + if len(uniqueTaskIDs) != numClients { + t.Errorf("Expected %d unique task IDs, got %d", numClients, len(uniqueTaskIDs)) + } +} + +// TestWorkerTaskLifecycle tests complete task lifecycle with worker +func TestWorkerTaskLifecycle(t *testing.T) { + listener := bufconn.Listen(bufSize) + defer listener.Close() + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + defer grpcServer.Stop() + + ctx := context.Background() + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer conn.Close() + + taskClient := pb.NewTaskQueueClient(conn) + workerClient := pb.NewWorkerServiceClient(conn) + + // 1. Register worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"lifecycle-test"}, + Capacity: 5, + Metadata: map[string]string{ + "address": "localhost:8000", + }, + }, + }) + if err != nil { + t.Fatalf("RegisterWorker failed: %v", err) + } + workerID := registerResp.WorkerId + + // 2. Submit task + submitResp, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "lifecycle-test", + Payload: []byte("lifecycle payload"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("SubmitTask failed: %v", err) + } + taskID := submitResp.TaskId + + // 3. Check task is pending + statusResp, err := taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("GetTaskStatus failed: %v", err) + } + if statusResp.Status != pb.TaskStatus_PENDING { + t.Errorf("Task status = %v, want PENDING", statusResp.Status) + } + + // 4. Worker fetches task + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: workerID, + TaskTypes: []string{"lifecycle-test"}, + }) + if err != nil { + t.Fatalf("FetchTask failed: %v", err) + } + if !fetchResp.HasTask { + t.Fatal("FetchTask has_task = false, expected task to be available") + } + if fetchResp.Task.Id != taskID { + t.Errorf("Fetched task ID = %v, want %v", fetchResp.Task.Id, taskID) + } + + // 5. Check task is running + statusResp, err = taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("GetTaskStatus failed: %v", err) + } + if statusResp.Status != pb.TaskStatus_RUNNING { + t.Errorf("Task status = %v, want RUNNING", statusResp.Status) + } + + // 6. Worker sends heartbeat + heartbeatResp, err := workerClient.Heartbeat(ctx, &pb.HeartbeatRequest{ + WorkerId: workerID, + CurrentLoad: 1, + }) + if err != nil { + t.Fatalf("Heartbeat failed: %v", err) + } + if !heartbeatResp.Success { + t.Error("Heartbeat success = false, want true") + } + + // 7. Worker submits result + submitResultResp, err := workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: []byte("task result data"), + Error: "", + }) + if err != nil { + t.Fatalf("SubmitResult failed: %v", err) + } + if !submitResultResp.Success { + t.Error("SubmitResult success = false, want true") + } + + // 8. Check task is completed + statusResp, err = taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("GetTaskStatus failed: %v", err) + } + if statusResp.Status != pb.TaskStatus_COMPLETED { + t.Errorf("Task status = %v, want COMPLETED", statusResp.Status) + } + + // 9. Get task result + resultResp, err := taskClient.GetTaskResult(ctx, &pb.GetTaskResultRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("GetTaskResult failed: %v", err) + } + if resultResp.Task == nil { + t.Fatal("GetTaskResult returned nil task") + } + if resultResp.Task.Id != taskID { + t.Errorf("Result task ID = %v, want %v", resultResp.Task.Id, taskID) + } + if resultResp.Task.Status != pb.TaskStatus_COMPLETED { + t.Errorf("Result task status = %v, want COMPLETED", resultResp.Task.Status) + } +} diff --git a/cmd/server/server.go b/cmd/server/server.go new file mode 100644 index 0000000..0659b17 --- /dev/null +++ b/cmd/server/server.go @@ -0,0 +1,38 @@ +package main + +import ( + "log" + "net" + + "github.com/mateusmlo/taskqueue/internal/helper" + "github.com/mateusmlo/taskqueue/internal/server" + pb "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func main() { + s := server.NewServer() + + tc, err := credentials.NewServerTLSFromFile("cert/server.crt", "cert/server.key") + if err != nil { + panic(err) + } + + grpcServer := grpc.NewServer(grpc.Creds(tc)) + + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + listener, err := net.Listen("tcp", ":50051") + if err != nil { + panic(err) + } + + helper.SetupGracefulShutdown(grpcServer.GracefulStop, "SERVER") + + log.Println("✅ Server listening on :50051") + if err := grpcServer.Serve(listener); err != nil { + panic(err) + } +} From f78145680199f2003d84b03d317a1448eae68870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:21:36 -0300 Subject: [PATCH 06/10] feat(cmd/worker): implement exmaple worker --- cmd/worker/worker.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 cmd/worker/worker.go diff --git a/cmd/worker/worker.go b/cmd/worker/worker.go new file mode 100644 index 0000000..402e9f2 --- /dev/null +++ b/cmd/worker/worker.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + "slices" + + "github.com/mateusmlo/taskqueue/internal/helper" + "github.com/mateusmlo/taskqueue/internal/worker" +) + +// An example task handler that just reverses a provided string +type ReverseStringHandler struct { +} + +// Handle reverses the input string +func (rh *ReverseStringHandler) Handle(ctx context.Context, payload []byte) ([]byte, error) { + s := slices.Clone(payload) + slices.Reverse(s) + + return s, nil +} + +func main() { + worker := worker.NewWorker("localhost:50051", 10) + + worker.RegisterHandler("reverseStr", &ReverseStringHandler{}) + + if err := worker.Start(); err != nil { + panic(err) + } + + helper.SetupGracefulShutdown(worker.Stop, worker.GetWorkerID()) + + // Keep the main goroutine running until shutdown signal received + select {} +} From 980958cb112a4fc92ae5ac474eb56b9bc4750b12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Mon, 15 Dec 2025 19:21:51 -0300 Subject: [PATCH 07/10] feat(cmd/client): implement example grpc client --- cmd/client/client.go | 73 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 cmd/client/client.go diff --git a/cmd/client/client.go b/cmd/client/client.go new file mode 100644 index 0000000..1f3795f --- /dev/null +++ b/cmd/client/client.go @@ -0,0 +1,73 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + pb "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func main() { + ctx := context.Background() + tc, err := credentials.NewClientTLSFromFile("cert/server.crt", "localhost") + if err != nil { + fmt.Printf("Failed to load TLS credentials: %s\n", err.Error()) + panic(err) + } + + clientConn, err := grpc.NewClient("localhost:50051", grpc.WithTransportCredentials(tc)) + if err != nil { + fmt.Printf("Failed to connect to server: %s\n", err.Error()) + panic(err) + } + defer clientConn.Close() + + taskClient := pb.NewTaskQueueClient(clientConn) + + res, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "reverseStr", + Payload: []byte("hello world"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + + if err != nil { + log.Fatalf("Task failed to submit: %s\n", err.Error()) + } + + taskID := res.TaskId + + ctxDeadline, cancelFunc := context.WithTimeout(ctx, 10*time.Second) + + for { + fmt.Print("Polling task status...\n") + + taskStatusRes, err := taskClient.GetTaskStatus(ctxDeadline, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + + if err != nil { + log.Fatalf("Failed to get task status: %s\n", err.Error()) + } + + if taskStatusRes.Status == pb.TaskStatus_COMPLETED { + log.Print("Task completed, fetching result...\n") + cancelFunc() + break + } + + time.Sleep(1 * time.Second) + } + + taskRes, err := taskClient.GetTaskResult(ctx, &pb.GetTaskResultRequest{TaskId: taskID}) + if err != nil { + log.Fatalf("Failed to get task result: %s\n", err.Error()) + panic(err) + } + + fmt.Printf("Task result: %s\n", taskRes.GetResult()) +} From 5e659dc4a4883c6e3ff52d1d718744096bf5c1d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Tue, 16 Dec 2025 17:07:50 -0300 Subject: [PATCH 08/10] feat(client): refactor code to be more decoupled --- cmd/client/client.go | 104 +++++++++++++++++++++++++++++++------------ 1 file changed, 75 insertions(+), 29 deletions(-) diff --git a/cmd/client/client.go b/cmd/client/client.go index 1f3795f..b7ade03 100644 --- a/cmd/client/client.go +++ b/cmd/client/client.go @@ -11,63 +11,109 @@ import ( "google.golang.org/grpc/credentials" ) -func main() { - ctx := context.Background() - tc, err := credentials.NewClientTLSFromFile("cert/server.crt", "localhost") +// createGRPCConnection creates a gRPC connection with TLS credentials +func createGRPCConnection(certPath, serverName, address string) (*grpc.ClientConn, error) { + tc, err := credentials.NewClientTLSFromFile(certPath, serverName) if err != nil { - fmt.Printf("Failed to load TLS credentials: %s\n", err.Error()) - panic(err) + return nil, fmt.Errorf("failed to load TLS credentials: %w", err) } - clientConn, err := grpc.NewClient("localhost:50051", grpc.WithTransportCredentials(tc)) + clientConn, err := grpc.NewClient(address, grpc.WithTransportCredentials(tc)) if err != nil { - fmt.Printf("Failed to connect to server: %s\n", err.Error()) - panic(err) + return nil, fmt.Errorf("failed to connect to server: %w", err) } - defer clientConn.Close() - taskClient := pb.NewTaskQueueClient(clientConn) + return clientConn, nil +} - res, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ - Type: "reverseStr", - Payload: []byte("hello world"), - Priority: int32(pb.Priority_HIGH), - MaxRetries: 3, +// submitTask submits a task to the task queue and returns the task ID +func submitTask(ctx context.Context, client pb.TaskQueueClient, taskType string, payload []byte, priority pb.Priority, maxRetries int32) (string, error) { + res, err := client.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: taskType, + Payload: payload, + Priority: int32(priority), + MaxRetries: maxRetries, }) if err != nil { - log.Fatalf("Task failed to submit: %s\n", err.Error()) + return "", fmt.Errorf("task failed to submit: %w", err) } - taskID := res.TaskId - - ctxDeadline, cancelFunc := context.WithTimeout(ctx, 10*time.Second) + return res.TaskId, nil +} +// pollTaskUntilComplete polls the task status until it completes or times out +func pollTaskUntilComplete(ctx context.Context, client pb.TaskQueueClient, taskID string, pollInterval time.Duration) error { for { - fmt.Print("Polling task status...\n") - - taskStatusRes, err := taskClient.GetTaskStatus(ctxDeadline, &pb.GetTaskStatusRequest{ + taskStatusRes, err := client.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ TaskId: taskID, }) if err != nil { - log.Fatalf("Failed to get task status: %s\n", err.Error()) + return fmt.Errorf("failed to get task status: %w", err) } if taskStatusRes.Status == pb.TaskStatus_COMPLETED { - log.Print("Task completed, fetching result...\n") - cancelFunc() - break + return nil + } + + if taskStatusRes.Status == pb.TaskStatus_FAILED { + return fmt.Errorf("task failed") } - time.Sleep(1 * time.Second) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(pollInterval): + // Continue polling + } } +} - taskRes, err := taskClient.GetTaskResult(ctx, &pb.GetTaskResultRequest{TaskId: taskID}) +// getTaskResult retrieves the task result +func getTaskResult(ctx context.Context, client pb.TaskQueueClient, taskID string) (*pb.GetTaskResultResponse, error) { + taskRes, err := client.GetTaskResult(ctx, &pb.GetTaskResultRequest{TaskId: taskID}) if err != nil { - log.Fatalf("Failed to get task result: %s\n", err.Error()) + return nil, fmt.Errorf("failed to get task result: %w", err) + } + + return taskRes, nil +} + +func main() { + ctx := context.Background() + + clientConn, err := createGRPCConnection("cert/server.crt", "localhost", "localhost:50051") + if err != nil { + fmt.Printf("Connection error: %s\n", err.Error()) panic(err) } + defer clientConn.Close() + + taskClient := pb.NewTaskQueueClient(clientConn) + + taskID, err := submitTask(ctx, taskClient, "reverseStr", []byte("hello world"), pb.Priority_HIGH, 3) + if err != nil { + log.Fatalf("Submit error: %s\n", err.Error()) + } + + fmt.Printf("Task submitted with ID: %s\n", taskID) + + ctxDeadline, cancelFunc := context.WithTimeout(ctx, 10*time.Second) + defer cancelFunc() + + fmt.Println("Polling task status...") + err = pollTaskUntilComplete(ctxDeadline, taskClient, taskID, 1*time.Second) + if err != nil { + log.Fatalf("Polling error: %s\n", err.Error()) + } + + log.Println("Task completed, fetching result...") + + taskRes, err := getTaskResult(ctx, taskClient, taskID) + if err != nil { + log.Fatalf("Get result error: %s\n", err.Error()) + } fmt.Printf("Task result: %s\n", taskRes.GetResult()) } From e6c5776daa3cf3448102d533b537708fb3bc0ea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Tue, 16 Dec 2025 17:08:01 -0300 Subject: [PATCH 09/10] feat: adds unit tests --- cmd/client/client_test.go | 515 ++++++++++++++++++++++++++++++++ cmd/server/main_test.go | 10 +- cmd/worker/worker_test.go | 526 +++++++++++++++++++++++++++++++++ internal/helper/helper_test.go | 358 ++++++++++++++++++++++ internal/server/server_test.go | 11 +- 5 files changed, 1403 insertions(+), 17 deletions(-) create mode 100644 cmd/client/client_test.go create mode 100644 cmd/worker/worker_test.go create mode 100644 internal/helper/helper_test.go diff --git a/cmd/client/client_test.go b/cmd/client/client_test.go new file mode 100644 index 0000000..79a9288 --- /dev/null +++ b/cmd/client/client_test.go @@ -0,0 +1,515 @@ +package main + +import ( + "context" + "net" + "testing" + "time" + + "github.com/mateusmlo/taskqueue/internal/server" + pb "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +// setupTestServer creates an in-memory test server for testing +func setupTestServer(t *testing.T) (*bufconn.Listener, *grpc.Server, pb.TaskQueueClient, func()) { + listener := bufconn.Listen(bufSize) + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + client := pb.NewTaskQueueClient(conn) + + cleanup := func() { + conn.Close() + grpcServer.Stop() + listener.Close() + } + + return listener, grpcServer, client, cleanup +} + +// TestSubmitTask tests the submitTask function +func TestSubmitTask(t *testing.T) { + _, _, client, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + tests := []struct { + name string + taskType string + payload []byte + priority pb.Priority + maxRetries int32 + expectError bool + }{ + { + name: "valid task submission", + taskType: "test-task", + payload: []byte("test payload"), + priority: pb.Priority_HIGH, + maxRetries: 3, + expectError: false, + }, + { + name: "task with medium priority", + taskType: "batch-job", + payload: []byte("large payload"), + priority: pb.Priority_MEDIUM, + maxRetries: 5, + expectError: false, + }, + { + name: "task with low priority", + taskType: "background-task", + payload: []byte("small payload"), + priority: pb.Priority_LOW, + maxRetries: 1, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + taskID, err := submitTask(ctx, client, tt.taskType, tt.payload, tt.priority, tt.maxRetries) + + if tt.expectError && err == nil { + t.Error("Expected error but got nil") + } + + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !tt.expectError && taskID == "" { + t.Error("Expected non-empty task ID") + } + + // Verify task was actually submitted by checking status + if !tt.expectError && taskID != "" { + statusResp, err := client.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Errorf("Failed to get task status: %v", err) + } + if statusResp.Status != pb.TaskStatus_PENDING { + t.Errorf("Expected status PENDING, got %v", statusResp.Status) + } + } + }) + } +} + +// TestPollTaskUntilComplete tests the polling functionality +func TestPollTaskUntilComplete(t *testing.T) { + _, _, taskClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Submit a task first + taskID, err := submitTask(ctx, taskClient, "poll-test", []byte("test"), pb.Priority_HIGH, 3) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + t.Run("timeout on pending task", func(t *testing.T) { + // Create a context with very short timeout + ctxTimeout, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + err := pollTaskUntilComplete(ctxTimeout, taskClient, taskID, 50*time.Millisecond) + if err == nil { + t.Error("Expected timeout error but got nil") + } + if err != context.DeadlineExceeded { + t.Errorf("Expected context.DeadlineExceeded, got %v", err) + } + }) + + t.Run("invalid task ID", func(t *testing.T) { + err := pollTaskUntilComplete(ctx, taskClient, "invalid-task-id", 50*time.Millisecond) + if err == nil { + t.Error("Expected error for invalid task ID but got nil") + } + }) +} + +// TestPollTaskUntilCompleteSuccess tests successful polling +func TestPollTaskUntilCompleteSuccess(t *testing.T) { + listener, _, taskClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create worker client + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker client: %v", err) + } + defer conn.Close() + + workerClient := pb.NewWorkerServiceClient(conn) + + // Register a worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"complete-test"}, + Capacity: 5, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + // Submit a task + taskID, err := submitTask(ctx, taskClient, "complete-test", []byte("test"), pb.Priority_HIGH, 3) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + // Worker fetches and completes the task in background + go func() { + time.Sleep(100 * time.Millisecond) + + // Fetch task + _, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: registerResp.WorkerId, + TaskTypes: []string{"complete-test"}, + }) + if err != nil { + t.Logf("FetchTask error: %v", err) + return + } + + // Submit result + _, err = workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: []byte("completed"), + Error: "", + }) + if err != nil { + t.Logf("SubmitResult error: %v", err) + } + }() + + // Poll for completion + ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + err = pollTaskUntilComplete(ctxTimeout, taskClient, taskID, 50*time.Millisecond) + if err != nil { + t.Errorf("Expected successful polling, got error: %v", err) + } +} + +// TestGetTaskResult tests retrieving task results +func TestGetTaskResult(t *testing.T) { + listener, _, taskClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create worker client + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker client: %v", err) + } + defer conn.Close() + + workerClient := pb.NewWorkerServiceClient(conn) + + // Register worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"result-test"}, + Capacity: 5, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + t.Run("get result of completed task", func(t *testing.T) { + // Submit task + taskID, err := submitTask(ctx, taskClient, "result-test", []byte("test"), pb.Priority_HIGH, 3) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + // Worker fetches and completes task + _, err = workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: registerResp.WorkerId, + TaskTypes: []string{"result-test"}, + }) + if err != nil { + t.Fatalf("Failed to fetch task: %v", err) + } + + expectedResult := []byte("task completed successfully") + _, err = workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: expectedResult, + Error: "", + }) + if err != nil { + t.Fatalf("Failed to submit result: %v", err) + } + + // Get result + result, err := getTaskResult(ctx, taskClient, taskID) + if err != nil { + t.Errorf("Failed to get task result: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if string(result.Result) != string(expectedResult) { + t.Errorf("Expected result %s, got %s", expectedResult, result.Result) + } + + // Verify task is completed by checking status + statusResp, err := taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Errorf("Failed to get task status: %v", err) + } + if statusResp.Status != pb.TaskStatus_COMPLETED { + t.Errorf("Expected status COMPLETED, got %v", statusResp.Status) + } + }) + + t.Run("get result of pending task", func(t *testing.T) { + // Submit task but don't complete it + taskID, err := submitTask(ctx, taskClient, "result-test", []byte("test"), pb.Priority_HIGH, 3) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + // Try to get result of pending task + _, err = getTaskResult(ctx, taskClient, taskID) + if err == nil { + t.Error("Expected error when getting result of pending task") + } + }) + + t.Run("get result with invalid task ID", func(t *testing.T) { + _, err := getTaskResult(ctx, taskClient, "invalid-task-id") + if err == nil { + t.Error("Expected error for invalid task ID") + } + }) +} + +// TestSubmitTaskConcurrent tests concurrent task submissions +func TestSubmitTaskConcurrent(t *testing.T) { + _, _, client, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + numTasks := 10 + taskIDs := make(chan string, numTasks) + errors := make(chan error, numTasks) + + for i := range numTasks { + go func(taskNum int) { + taskID, err := submitTask( + ctx, + client, + "concurrent-test", + []byte("test payload"), + pb.Priority_MEDIUM, + 3, + ) + if err != nil { + errors <- err + } else { + taskIDs <- taskID + errors <- nil + } + }(i) + } + + // Collect results + uniqueTaskIDs := make(map[string]bool) + for range numTasks { + err := <-errors + if err != nil { + t.Errorf("Concurrent task submission failed: %v", err) + continue + } + + taskID := <-taskIDs + if taskID == "" { + t.Error("Got empty task ID") + continue + } + + if uniqueTaskIDs[taskID] { + t.Errorf("Duplicate task ID: %s", taskID) + } + uniqueTaskIDs[taskID] = true + } + + if len(uniqueTaskIDs) != numTasks { + t.Errorf("Expected %d unique task IDs, got %d", numTasks, len(uniqueTaskIDs)) + } +} + +// TestClientWorkflow tests the complete client workflow +func TestClientWorkflow(t *testing.T) { + listener, _, taskClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create worker client + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker client: %v", err) + } + defer conn.Close() + + workerClient := pb.NewWorkerServiceClient(conn) + + // Register worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"workflow-test"}, + Capacity: 5, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + // Step 1: Submit task + taskID, err := submitTask(ctx, taskClient, "workflow-test", []byte("hello world"), pb.Priority_HIGH, 3) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + if taskID == "" { + t.Fatal("Expected non-empty task ID") + } + + // Step 2: Worker processes task in background + go func() { + time.Sleep(200 * time.Millisecond) + + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: registerResp.WorkerId, + TaskTypes: []string{"workflow-test"}, + }) + if err != nil { + t.Logf("FetchTask error: %v", err) + return + } + + if !fetchResp.HasTask { + t.Log("No task available") + return + } + + // Simulate task processing + time.Sleep(100 * time.Millisecond) + + _, err = workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: []byte("dlrow olleh"), + Error: "", + }) + if err != nil { + t.Logf("SubmitResult error: %v", err) + } + }() + + // Step 3: Poll for completion + ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + err = pollTaskUntilComplete(ctxTimeout, taskClient, taskID, 100*time.Millisecond) + if err != nil { + t.Fatalf("Failed to poll task: %v", err) + } + + // Step 4: Get result + result, err := getTaskResult(ctx, taskClient, taskID) + if err != nil { + t.Fatalf("Failed to get task result: %v", err) + } + + if result == nil { + t.Fatal("Expected non-nil result") + } + + if string(result.Result) != "dlrow olleh" { + t.Errorf("Expected result 'dlrow olleh', got '%s'", result.Result) + } + + // Verify task completed successfully + statusResp, err := taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Errorf("Failed to get task status: %v", err) + } + if statusResp.Status != pb.TaskStatus_COMPLETED { + t.Errorf("Expected status COMPLETED, got %v", statusResp.Status) + } +} diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go index 32759b5..9b01f5c 100644 --- a/cmd/server/main_test.go +++ b/cmd/server/main_test.go @@ -564,13 +564,7 @@ func TestWorkerTaskLifecycle(t *testing.T) { if err != nil { t.Fatalf("GetTaskResult failed: %v", err) } - if resultResp.Task == nil { - t.Fatal("GetTaskResult returned nil task") - } - if resultResp.Task.Id != taskID { - t.Errorf("Result task ID = %v, want %v", resultResp.Task.Id, taskID) - } - if resultResp.Task.Status != pb.TaskStatus_COMPLETED { - t.Errorf("Result task status = %v, want COMPLETED", resultResp.Task.Status) + if resultResp.Result == nil { + t.Fatal("GetTaskResult returned nil result") } } diff --git a/cmd/worker/worker_test.go b/cmd/worker/worker_test.go new file mode 100644 index 0000000..f1609a2 --- /dev/null +++ b/cmd/worker/worker_test.go @@ -0,0 +1,526 @@ +package main + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/mateusmlo/taskqueue/internal/server" + "github.com/mateusmlo/taskqueue/internal/worker" + pb "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +// setupTestServer creates an in-memory test server for testing +func setupTestServer(t *testing.T) (*bufconn.Listener, *grpc.Server, pb.TaskQueueClient, pb.WorkerServiceClient, func()) { + listener := bufconn.Listen(bufSize) + + s := server.NewServer() + grpcServer := grpc.NewServer() + pb.RegisterTaskQueueServer(grpcServer, s) + pb.RegisterWorkerServiceServer(grpcServer, s) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server error: %v", err) + } + }() + + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + + taskClient := pb.NewTaskQueueClient(conn) + workerClient := pb.NewWorkerServiceClient(conn) + + cleanup := func() { + conn.Close() + grpcServer.Stop() + listener.Close() + } + + return listener, grpcServer, taskClient, workerClient, cleanup +} + +// TestReverseStringHandler tests the ReverseStringHandler +func TestReverseStringHandler(t *testing.T) { + handler := &ReverseStringHandler{} + ctx := context.Background() + + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "simple string", + input: []byte("hello"), + expected: []byte("olleh"), + }, + { + name: "empty string", + input: []byte(""), + expected: []byte(""), + }, + { + name: "single character", + input: []byte("a"), + expected: []byte("a"), + }, + { + name: "palindrome", + input: []byte("racecar"), + expected: []byte("racecar"), + }, + { + name: "string with spaces", + input: []byte("hello world"), + expected: []byte("dlrow olleh"), + }, + { + name: "numbers", + input: []byte("12345"), + expected: []byte("54321"), + }, + { + name: "special characters", + input: []byte("!@#$%"), + expected: []byte("%$#@!"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := handler.Handle(ctx, tt.input) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if string(result) != string(tt.expected) { + t.Errorf("Expected %s, got %s", tt.expected, result) + } + + // Verify original input is not modified + if tt.name == "simple string" && string(tt.input) != "hello" { + t.Error("Handler should not modify original input") + } + }) + } +} + +// TestReverseStringHandlerConcurrent tests the handler with concurrent requests +func TestReverseStringHandlerConcurrent(t *testing.T) { + handler := &ReverseStringHandler{} + ctx := context.Background() + + numGoroutines := 100 + results := make(chan error, numGoroutines) + + for i := range numGoroutines { + go func(n int) { + input := []byte("test") + expected := []byte("tset") + + result, err := handler.Handle(ctx, input) + if err != nil { + results <- err + return + } + + if string(result) != string(expected) { + results <- errors.New("incorrect result") + return + } + + results <- nil + }(i) + } + + for range numGoroutines { + if err := <-results; err != nil { + t.Errorf("Concurrent handler test failed: %v", err) + } + } +} + +// MockTaskHandler is a test handler that tracks invocations +type MockTaskHandler struct { + handleFunc func(ctx context.Context, payload []byte) ([]byte, error) + invocations int +} + +func (m *MockTaskHandler) Handle(ctx context.Context, payload []byte) ([]byte, error) { + m.invocations++ + if m.handleFunc != nil { + return m.handleFunc(ctx, payload) + } + return payload, nil +} + +// TestWorkerHandlerRegistration tests worker handler registration +func TestWorkerHandlerRegistration(t *testing.T) { + w := worker.NewWorker("localhost:50051", 10) + + handler1 := &ReverseStringHandler{} + handler2 := &MockTaskHandler{} + + w.RegisterHandler("reverseStr", handler1) + w.RegisterHandler("mockTask", handler2) +} + +// TestWorkerTaskProcessing tests end-to-end task processing with a worker +func TestWorkerTaskProcessing(t *testing.T) { + listener, _, taskClient, workerClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create a connection for the worker to use + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker connection: %v", err) + } + defer conn.Close() + + // Register a worker with ReverseStringHandler + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"reverseStr"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + workerID := registerResp.WorkerId + + // Submit a task + submitResp, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "reverseStr", + Payload: []byte("hello world"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + taskID := submitResp.TaskId + + // Fetch the task + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: workerID, + TaskTypes: []string{"reverseStr"}, + }) + if err != nil { + t.Fatalf("Failed to fetch task: %v", err) + } + + if !fetchResp.HasTask { + t.Fatal("Expected task to be available") + } + + // Process the task using the handler + handler := &ReverseStringHandler{} + result, err := handler.Handle(ctx, fetchResp.Task.Payload) + if err != nil { + t.Fatalf("Handler failed: %v", err) + } + + // Submit the result + _, err = workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: result, + Error: "", + }) + if err != nil { + t.Fatalf("Failed to submit result: %v", err) + } + + // Verify the task is completed + statusResp, err := taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("Failed to get task status: %v", err) + } + + if statusResp.Status != pb.TaskStatus_COMPLETED { + t.Errorf("Expected status COMPLETED, got %v", statusResp.Status) + } + + // Verify the result + resultResp, err := taskClient.GetTaskResult(ctx, &pb.GetTaskResultRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("Failed to get task result: %v", err) + } + + expectedResult := "dlrow olleh" + if string(resultResp.Result) != expectedResult { + t.Errorf("Expected result %s, got %s", expectedResult, resultResp.Result) + } +} + +// TestWorkerMultipleTaskTypes tests worker handling multiple task types +func TestWorkerMultipleTaskTypes(t *testing.T) { + listener, _, taskClient, workerClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create worker connection + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker connection: %v", err) + } + defer conn.Close() + + // Register worker with multiple task types + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"reverseStr", "anotherTask"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + workerID := registerResp.WorkerId + + // Submit tasks of different types + taskTypes := []string{"reverseStr", "anotherTask"} + taskIDs := make([]string, len(taskTypes)) + + for i, taskType := range taskTypes { + submitResp, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: taskType, + Payload: []byte("test"), + Priority: int32(pb.Priority_MEDIUM), + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + taskIDs[i] = submitResp.TaskId + } + + // Worker should be able to fetch both task types + fetchedTasks := make(map[string]bool) + + for range taskTypes { + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: workerID, + TaskTypes: taskTypes, + }) + if err != nil { + t.Fatalf("Failed to fetch task: %v", err) + } + + if fetchResp.HasTask { + fetchedTasks[fetchResp.Task.Type] = true + } + } + + if len(fetchedTasks) != len(taskTypes) { + t.Errorf("Expected to fetch %d different task types, got %d", len(taskTypes), len(fetchedTasks)) + } + + for _, taskType := range taskTypes { + if !fetchedTasks[taskType] { + t.Errorf("Did not fetch task of type %s", taskType) + } + } +} + +// TestWorkerHeartbeat tests worker heartbeat functionality +func TestWorkerHeartbeat(t *testing.T) { + _, _, _, workerClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Register worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"test-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + workerID := registerResp.WorkerId + + // Send heartbeat + heartbeatResp, err := workerClient.Heartbeat(ctx, &pb.HeartbeatRequest{ + WorkerId: workerID, + CurrentLoad: 5, + }) + if err != nil { + t.Fatalf("Failed to send heartbeat: %v", err) + } + + if !heartbeatResp.Success { + t.Error("Expected heartbeat to succeed") + } + + if heartbeatResp.CurrentLoad != 5 { + t.Errorf("Expected current load 5, got %d", heartbeatResp.CurrentLoad) + } +} + +// TestWorkerErrorHandling tests how worker handles errors from task handlers +func TestWorkerErrorHandling(t *testing.T) { + listener, _, taskClient, workerClient, cleanup := setupTestServer(t) + defer cleanup() + + ctx := context.Background() + + // Create worker connection + conn, err := grpc.NewClient( + "passthrough:///bufnet", + grpc.WithContextDialer(func(ctx context.Context, url string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("Failed to create worker connection: %v", err) + } + defer conn.Close() + + // Register worker + registerResp, err := workerClient.RegisterWorker(ctx, &pb.RegisterWorkerRequest{ + Worker: &pb.Worker{ + TaskTypes: []string{"error-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + }) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + workerID := registerResp.WorkerId + + // Submit a task + submitResp, err := taskClient.SubmitTask(ctx, &pb.SubmitTaskRequest{ + Type: "error-task", + Payload: []byte("test"), + Priority: int32(pb.Priority_HIGH), + MaxRetries: 3, + }) + if err != nil { + t.Fatalf("Failed to submit task: %v", err) + } + + taskID := submitResp.TaskId + + // Fetch task + fetchResp, err := workerClient.FetchTask(ctx, &pb.FetchTaskRequest{ + WorkerId: workerID, + TaskTypes: []string{"error-task"}, + }) + if err != nil { + t.Fatalf("Failed to fetch task: %v", err) + } + + if !fetchResp.HasTask { + t.Fatal("Expected task to be available") + } + + // Simulate handler error by submitting error result + expectedError := "task processing failed" + _, err = workerClient.SubmitResult(ctx, &pb.SubmitResultRequest{ + TaskId: taskID, + Result: nil, + Error: expectedError, + }) + if err != nil { + t.Fatalf("Failed to submit error result: %v", err) + } + + // Verify task failed + time.Sleep(100 * time.Millisecond) + statusResp, err := taskClient.GetTaskStatus(ctx, &pb.GetTaskStatusRequest{ + TaskId: taskID, + }) + if err != nil { + t.Fatalf("Failed to get task status: %v", err) + } + + // Task should be marked as failed or moved back to pending for retry + if statusResp.Status != pb.TaskStatus_FAILED && statusResp.Status != pb.TaskStatus_PENDING { + t.Logf("Task status after error: %v (expected FAILED or PENDING)", statusResp.Status) + } +} + +// TestReverseStringHandlerLargeInput tests handler with large input +func TestReverseStringHandlerLargeInput(t *testing.T) { + handler := &ReverseStringHandler{} + ctx := context.Background() + + // Create a large input (1MB) + largeInput := make([]byte, 1024*1024) + for i := range largeInput { + largeInput[i] = byte(i % 256) + } + + result, err := handler.Handle(ctx, largeInput) + if err != nil { + t.Fatalf("Handler failed with large input: %v", err) + } + + if len(result) != len(largeInput) { + t.Errorf("Expected result length %d, got %d", len(largeInput), len(result)) + } + + // Verify reversal is correct + for i := range largeInput { + if result[i] != largeInput[len(largeInput)-1-i] { + t.Errorf("Incorrect reversal at index %d", i) + break + } + } +} diff --git a/internal/helper/helper_test.go b/internal/helper/helper_test.go new file mode 100644 index 0000000..9de70c6 --- /dev/null +++ b/internal/helper/helper_test.go @@ -0,0 +1,358 @@ +package helper + +import ( + "os" + "sync" + "syscall" + "testing" + "time" +) + +// TestSetupGracefulShutdown tests the graceful shutdown setup +func TestSetupGracefulShutdown(t *testing.T) { + var mu sync.Mutex + + shutdownFn := func() { + mu.Lock() + defer mu.Unlock() + // Shutdown logic would go here + } + + SetupGracefulShutdown(shutdownFn, "TEST") + + // Send SIGTERM to trigger shutdown + // Note: We can't easily test this without causing the test to exit + // because SetupGracefulShutdown calls os.Exit(0) + // So we'll just verify the function doesn't panic when called +} + +// TestSetupGracefulShutdownWithMultipleCalls tests multiple setups don't interfere +func TestSetupGracefulShutdownWithMultipleCalls(t *testing.T) { + callCount := 0 + var mu sync.Mutex + + shutdownFn := func() { + mu.Lock() + defer mu.Unlock() + callCount++ + } + + // Setup multiple times (simulating multiple components) + SetupGracefulShutdown(shutdownFn, "TEST1") + SetupGracefulShutdown(shutdownFn, "TEST2") + + // Both should be set up without errors + // In practice, this might create multiple signal handlers +} + +// TestSetupGracefulShutdownWithNilFunction tests that setup handles nil gracefully +func TestSetupGracefulShutdownWithNilFunction(t *testing.T) { + // This should not panic even with a nil function + // Though in practice, it will panic when the signal is received + defer func() { + if r := recover(); r != nil { + t.Errorf("SetupGracefulShutdown panicked with nil function during setup: %v", r) + } + }() + + // Setup should complete without panic + SetupGracefulShutdown(nil, "TEST") +} + +// TestSetupGracefulShutdownWithEmptyCaller tests setup with empty caller string +func TestSetupGracefulShutdownWithEmptyCaller(t *testing.T) { + var mu sync.Mutex + + shutdownFn := func() { + mu.Lock() + defer mu.Unlock() + // Shutdown logic + } + + // Should work with empty caller + SetupGracefulShutdown(shutdownFn, "") +} + +// TestSetupGracefulShutdownSignalHandling tests that signals are properly registered +func TestSetupGracefulShutdownSignalHandling(t *testing.T) { + // This test verifies that the signal handler is set up + // We create a custom implementation that doesn't call os.Exit + // to test the signal handling mechanism + + shutdownCalled := false + var mu sync.Mutex + done := make(chan bool, 1) + + shutdownFn := func() { + mu.Lock() + shutdownCalled = true + mu.Unlock() + done <- true + } + + // Create a custom signal handler (similar to SetupGracefulShutdown but testable) + c := make(chan os.Signal, 1) + // We don't use signal.Notify here to avoid interfering with actual signal handling + + // Simulate what happens when a signal is received + go func() { + // Simulate receiving a signal + c <- syscall.SIGTERM + + // Call the shutdown function + shutdownFn() + }() + + // Wait for shutdown to be called + select { + case <-done: + mu.Lock() + if !shutdownCalled { + t.Error("Shutdown function was not called") + } + mu.Unlock() + case <-time.After(2 * time.Second): + t.Error("Timeout waiting for shutdown") + } +} + +// TestSetupGracefulShutdownWithSIGINT tests handling of SIGINT +func TestSetupGracefulShutdownWithSIGINT(t *testing.T) { + // Similar to above but with SIGINT + shutdownCalled := false + var mu sync.Mutex + done := make(chan bool, 1) + + shutdownFn := func() { + mu.Lock() + shutdownCalled = true + mu.Unlock() + done <- true + } + + c := make(chan os.Signal, 1) + + go func() { + c <- os.Interrupt + shutdownFn() + }() + + select { + case <-done: + mu.Lock() + if !shutdownCalled { + t.Error("Shutdown function was not called for SIGINT") + } + mu.Unlock() + case <-time.After(2 * time.Second): + t.Error("Timeout waiting for shutdown with SIGINT") + } +} + +// TestSetupGracefulShutdownConcurrent tests concurrent signal handling +func TestSetupGracefulShutdownConcurrent(t *testing.T) { + callCount := 0 + var mu sync.Mutex + done := make(chan bool, 10) + + shutdownFn := func() { + mu.Lock() + callCount++ + mu.Unlock() + done <- true + } + + // Simulate multiple shutdown calls (which might happen with multiple signals) + numCalls := 10 + for range numCalls { + go func() { + shutdownFn() + }() + } + + // Wait for all calls to complete + for range numCalls { + select { + case <-done: + // Success + case <-time.After(2 * time.Second): + t.Error("Timeout waiting for concurrent shutdown calls") + return + } + } + + mu.Lock() + if callCount != numCalls { + t.Errorf("Expected %d shutdown calls, got %d", numCalls, callCount) + } + mu.Unlock() +} + +// TestSetupGracefulShutdownWithLongRunningShutdown tests shutdown with long-running cleanup +func TestSetupGracefulShutdownWithLongRunningShutdown(t *testing.T) { + shutdownStarted := false + shutdownCompleted := false + var mu sync.Mutex + done := make(chan bool) + + shutdownFn := func() { + mu.Lock() + shutdownStarted = true + mu.Unlock() + + // Simulate long-running cleanup + time.Sleep(100 * time.Millisecond) + + mu.Lock() + shutdownCompleted = true + mu.Unlock() + + done <- true + } + + go func() { + shutdownFn() + }() + + select { + case <-done: + mu.Lock() + if !shutdownStarted { + t.Error("Shutdown did not start") + } + if !shutdownCompleted { + t.Error("Shutdown did not complete") + } + mu.Unlock() + case <-time.After(2 * time.Second): + t.Error("Timeout waiting for long-running shutdown") + } +} + +// TestSetupGracefulShutdownWithPanic tests that shutdown handles panics gracefully +func TestSetupGracefulShutdownWithPanic(t *testing.T) { + shutdownFn := func() { + panic("shutdown panic") + } + + // This should be handled by the caller or result in process termination + // We test that setting it up doesn't panic immediately + defer func() { + if r := recover(); r != nil { + t.Errorf("SetupGracefulShutdown panicked during setup: %v", r) + } + }() + + SetupGracefulShutdown(shutdownFn, "TEST") +} + +// TestSetupGracefulShutdownMultipleSignalTypes tests handling different signal types +func TestSetupGracefulShutdownMultipleSignalTypes(t *testing.T) { + signals := []os.Signal{ + os.Interrupt, + syscall.SIGTERM, + } + + for _, sig := range signals { + t.Run(sig.String(), func(t *testing.T) { + shutdownCalled := false + var mu sync.Mutex + done := make(chan bool, 1) + + shutdownFn := func() { + mu.Lock() + shutdownCalled = true + mu.Unlock() + done <- true + } + + // Simulate signal handling + go func() { + shutdownFn() + }() + + select { + case <-done: + mu.Lock() + if !shutdownCalled { + t.Errorf("Shutdown not called for signal %v", sig) + } + mu.Unlock() + case <-time.After(2 * time.Second): + t.Errorf("Timeout waiting for shutdown with signal %v", sig) + } + }) + } +} + +// TestSetupGracefulShutdownCallerIdentification tests different caller identifications +func TestSetupGracefulShutdownCallerIdentification(t *testing.T) { + callers := []string{ + "SERVER", + "WORKER", + "CLIENT", + "TEST", + "", + "Very Long Caller Name With Spaces And Special Characters !@#$%", + } + + for _, caller := range callers { + t.Run(caller, func(t *testing.T) { + var mu sync.Mutex + + shutdownFn := func() { + mu.Lock() + defer mu.Unlock() + // Shutdown logic + } + + // Should not panic with any caller name + defer func() { + if r := recover(); r != nil { + t.Errorf("Panicked with caller '%s': %v", caller, r) + } + }() + + SetupGracefulShutdown(shutdownFn, caller) + }) + } +} + +// TestEventCaller tests the EventCaller type +func TestEventCaller(t *testing.T) { + // EventCaller is defined but not currently used + // Test that it can be instantiated + var ec EventCaller + ec = 0 + + if ec != 0 { + t.Errorf("Expected EventCaller to be 0, got %d", ec) + } + + // Test that it can hold different values + ec = 1 + if ec != 1 { + t.Errorf("Expected EventCaller to be 1, got %d", ec) + } +} + +// BenchmarkSetupGracefulShutdown benchmarks the setup function +func BenchmarkSetupGracefulShutdown(b *testing.B) { + shutdownFn := func() {} + + for b.Loop() { + SetupGracefulShutdown(shutdownFn, "BENCHMARK") + } +} + +// BenchmarkShutdownFunction benchmarks the shutdown function execution +func BenchmarkShutdownFunction(b *testing.B) { + counter := 0 + shutdownFn := func() { + counter++ + } + + for b.Loop() { + shutdownFn() + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index c292c1d..2df49d6 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -404,6 +404,7 @@ func TestServer_GetTaskResult(t *testing.T) { s.tasksMux.Lock() completedTask := s.tasks[submitResp.TaskId] completedTask.Status = COMPLETED + completedTask.Result = []byte("task completed successfully") now := time.Now() completedTask.CompletedAt = &now s.tasksMux.Unlock() @@ -478,17 +479,9 @@ func TestServer_GetTaskResult(t *testing.T) { t.Fatal("GetTaskResult() returned nil response") } - if resp.Task == nil { + if resp.Result == nil { t.Fatal("GetTaskResult() returned nil task") } - - if resp.Task.Id != tt.taskID { - t.Errorf("GetTaskResult() task ID = %v, want %v", resp.Task.Id, tt.taskID) - } - - if resp.Task.Status != proto.TaskStatus_COMPLETED { - t.Errorf("GetTaskResult() task status = %v, want COMPLETED", resp.Task.Status) - } } }) } From 767362602cb9c1db8e69b0675be21724d2632a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Mendon=C3=A7a?= Date: Tue, 16 Dec 2025 17:20:16 -0300 Subject: [PATCH 10/10] fix(server): prevents race condition w/ maps --- internal/server/server.go | 76 ++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index adcf578..aa2bad2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -132,26 +132,36 @@ func (s *Server) SubmitTask(ctx context.Context, req *proto.SubmitTaskRequest) ( // GetTaskStatus retrieves the status of a task by its ID func (s *Server) GetTaskStatus(ctx context.Context, req *proto.GetTaskStatusRequest) (*proto.GetTaskStatusResponse, error) { - task, err := s.findTask(req.TaskId) - if err != nil { - return nil, err + s.tasksMux.RLock() + task, exists := s.tasks[req.TaskId] + if !exists { + s.tasksMux.RUnlock() + return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId) } + taskStatus := task.Status + s.tasksMux.RUnlock() - return &proto.GetTaskStatusResponse{Status: proto.TaskStatus(task.Status)}, nil + return &proto.GetTaskStatusResponse{Status: proto.TaskStatus(taskStatus)}, nil } // GetTaskResult retrieves the result of a completed task by its ID func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequest) (*proto.GetTaskResultResponse, error) { - task, err := s.findTask(req.TaskId) - if err != nil { - return nil, err + s.tasksMux.RLock() + task, exists := s.tasks[req.TaskId] + if !exists { + s.tasksMux.RUnlock() + return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId) } if task.Status != COMPLETED { + s.tasksMux.RUnlock() return nil, status.Errorf(codes.FailedPrecondition, "task %s not completed yet", req.TaskId) } - return &proto.GetTaskResultResponse{Result: task.Result}, nil + result := task.Result + s.tasksMux.RUnlock() + + return &proto.GetTaskResultResponse{Result: result}, nil } // RegisterWorker handles worker registration requests @@ -169,28 +179,33 @@ func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRe // Heartbeat processes heartbeat messages from workers func (s *Server) Heartbeat(ctx context.Context, req *proto.HeartbeatRequest) (*proto.HeartbeatResponse, error) { - worker, err := s.findWorker(req.WorkerId) - if err != nil { - return nil, err + s.workersMux.Lock() + worker, exists := s.workers[req.WorkerId] + if !exists { + s.workersMux.Unlock() + return nil, status.Errorf(codes.NotFound, "worker %s not found", req.WorkerId) } worker.LastHeartbeat = time.Now() worker.CurrentLoad = int(req.CurrentLoad) + currentLoad := worker.CurrentLoad + s.workersMux.Unlock() - return &proto.HeartbeatResponse{Success: true, CurrentLoad: int32(worker.CurrentLoad)}, nil + return &proto.HeartbeatResponse{Success: true, CurrentLoad: int32(currentLoad)}, nil } // SubmitResult processes the result submission from workers func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultRequest) (*proto.SubmitResultResponse, error) { - task, err := s.findTask(req.TaskId) - if err != nil { - return nil, err + s.tasksMux.Lock() + task, exists := s.tasks[req.TaskId] + if !exists { + s.tasksMux.Unlock() + return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId) } now := time.Now() task.CompletedAt = &now - - defer s.decrementCurrentLoad(task.WorkerID) + workerID := task.WorkerID if req.Error != "" { task.Error = req.Error @@ -200,30 +215,43 @@ func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultReques task.Status = PENDING task.StartedAt = nil task.CompletedAt = nil + s.tasksMux.Unlock() s.appendTaskToQueue(task) + s.decrementCurrentLoad(workerID) + + return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil } else { task.Status = FAILED + s.tasksMux.Unlock() + s.decrementCurrentLoad(workerID) return nil, status.Errorf(codes.DeadlineExceeded, "task %s failed after maximum retries: %s", req.TaskId, req.Error) } } else { task.Status = COMPLETED task.Result = req.Result - } + s.tasksMux.Unlock() - return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil + s.decrementCurrentLoad(workerID) + return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil + } } func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*proto.FetchTaskResponse, error) { - worker, err := s.findWorker(req.WorkerId) - if err != nil { - return nil, err + s.workersMux.RLock() + worker, exists := s.workers[req.WorkerId] + if !exists { + s.workersMux.RUnlock() + return nil, status.Errorf(codes.NotFound, "worker %s not found", req.WorkerId) } if worker.CurrentLoad >= worker.Capacity { + s.workersMux.RUnlock() return &proto.FetchTaskResponse{HasTask: false}, nil } + workerID := worker.ID + s.workersMux.RUnlock() s.queuesMux.Lock() defer s.queuesMux.Unlock() @@ -241,10 +269,10 @@ func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*p s.tasksMux.Lock() task.Status = RUNNING task.StartedAt = &now - task.WorkerID = worker.ID + task.WorkerID = workerID s.tasksMux.Unlock() - s.incrementCurrentLoad(worker.ID) + s.incrementCurrentLoad(workerID) return &proto.FetchTaskResponse{Task: task.toProtoTask(), HasTask: true}, nil }