diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..05b6621 --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# macOS +.DS_Store +.AppleDouble +.LSOverride + +# Coverage files +coverage.out +*.coverprofile +*.coverage +coverage/ +*.lcov + +# Go build artifacts +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binaries +*.test + +# Go workspace file +go.work + +# Build directories +bin/ +dist/ +build/ + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ +.project +.classpath +.settings/ + +# Environment files +.env +.env.local +.env.*.local + +# Logs +*.log + +# Temporary files +tmp/ +temp/ +*.tmp diff --git a/internal/server/server.go b/internal/server/server.go index ed3b185..aa73137 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,7 +1,8 @@ -package internal +package server import ( "context" + "slices" "sync" "time" @@ -10,6 +11,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/mateusmlo/taskqueue/internal/worker" "github.com/mateusmlo/taskqueue/proto" ) @@ -29,6 +31,7 @@ const ( FAILED ) +// Server struct implements the TaskQueue and WorkerService gRPC servers type Server struct { tasks map[string]*Task tasksMux sync.RWMutex @@ -36,7 +39,7 @@ type Server struct { pendingQueues map[Priority][]*Task queuesMux sync.RWMutex - workers map[string]*Worker + workers map[string]*worker.Worker workersMux sync.RWMutex ctx context.Context @@ -46,6 +49,7 @@ type Server struct { proto.UnimplementedWorkerServiceServer } +// Task represents a unit of work in the task queue system type Task struct { ID string Type string @@ -62,23 +66,14 @@ type Task struct { WorkerID string } -type Worker struct { - ID string - Address string - RegisteredAt time.Time - LastHeartbeat time.Time - TaskTypes []string - Capacity int - CurrentLoad int -} - +// NewServer initializes and returns a new Server instance func NewServer() *Server { ctx, cancel := context.WithCancel(context.Background()) return &Server{ tasks: make(map[string]*Task), pendingQueues: make(map[Priority][]*Task), - workers: make(map[string]*Worker), + workers: make(map[string]*worker.Worker), ctx: ctx, cancel: cancel, } @@ -108,6 +103,7 @@ func (t *Task) toProtoTask() *proto.Task { return protoTask } +// SubmitTask handles task submission requests func (s *Server) SubmitTask(ctx context.Context, req *proto.SubmitTaskRequest) (*proto.SubmitTaskResponse, error) { uuid, err := uuid.NewV7() if err != nil { @@ -131,38 +127,187 @@ func (s *Server) SubmitTask(ctx context.Context, req *proto.SubmitTaskRequest) ( s.tasks[taskID] = newTask - s.queuesMux.Lock() - defer s.queuesMux.Unlock() - - s.pendingQueues[newTask.Priority] = append(s.pendingQueues[newTask.Priority], newTask) + s.appendTaskToQueue(newTask) return &proto.SubmitTaskResponse{TaskId: newTask.ID}, nil } +// GetTaskStatus retrieves the status of a task by its ID func (s *Server) GetTaskStatus(ctx context.Context, req *proto.GetTaskStatusRequest) (*proto.GetTaskStatusResponse, error) { - s.tasksMux.RLock() - defer s.tasksMux.RUnlock() - - task, exists := s.tasks[req.TaskId] - if !exists { - return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId) + task, err := s.findTask(req.TaskId) + if err != nil { + return nil, err } return &proto.GetTaskStatusResponse{Status: proto.TaskStatus(task.Status)}, nil } +// GetTaskResult retrieves the result of a completed task by its ID func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequest) (*proto.GetTaskResultResponse, error) { + task, err := s.findTask(req.TaskId) + if err != nil { + return nil, err + } + + if task.Status != COMPLETED { + return nil, status.Errorf(codes.FailedPrecondition, "task %s not completed yet", req.TaskId) + } + + return &proto.GetTaskResultResponse{Task: task.toProtoTask()}, nil +} + +// RegisterWorker handles worker registration requests +func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRequest) (*proto.RegisterWorkerResponse, error) { + var newWorker worker.Worker + newWorker.FromProtoWorker(req.Worker) + + s.workersMux.Lock() + defer s.workersMux.Unlock() + + s.workers[newWorker.ID] = &newWorker + + return &proto.RegisterWorkerResponse{WorkerId: newWorker.ID, Success: true}, nil +} + +// Heartbeat processes heartbeat messages from workers +func (s *Server) Heartbeat(ctx context.Context, req *proto.HeartbeatRequest) (*proto.HeartbeatResponse, error) { + worker, err := s.findWorker(req.WorkerId) + if err != nil { + return nil, err + } + + worker.LastHeartbeat = time.Now() + worker.CurrentLoad = int(req.CurrentLoad) + + return &proto.HeartbeatResponse{Success: true, CurrentLoad: int32(worker.CurrentLoad)}, nil +} + +// SubmitResult processes the result submission from workers +func (s *Server) SubmitResult(ctx context.Context, req *proto.SubmitResultRequest) (*proto.SubmitResultResponse, error) { + task, err := s.findTask(req.TaskId) + if err != nil { + return nil, err + } + + now := time.Now() + task.CompletedAt = &now + + defer s.decrementCurrentLoad(task.WorkerID) + + if req.Error != "" { + task.Error = req.Error + task.RetryCount++ + + if task.RetryCount < task.MaxRetries { + task.Status = PENDING + task.StartedAt = nil + task.CompletedAt = nil + + s.appendTaskToQueue(task) + } else { + task.Status = FAILED + + return nil, status.Errorf(codes.DeadlineExceeded, "task %s failed after maximum retries: %s", req.TaskId, req.Error) + } + } else { + task.Status = COMPLETED + task.Result = req.Result + } + + return &proto.SubmitResultResponse{Success: true, Result: req.Result}, nil +} + +func (s *Server) FetchTask(ctx context.Context, req *proto.FetchTaskRequest) (*proto.FetchTaskResponse, error) { + worker, err := s.findWorker(req.WorkerId) + if err != nil { + return nil, err + } + + if worker.CurrentLoad >= worker.Capacity { + return &proto.FetchTaskResponse{HasTask: false}, nil + } + + s.queuesMux.Lock() + defer s.queuesMux.Unlock() + + for priority := HIGH; priority <= LOW; priority++ { + queue := s.pendingQueues[priority] + for i, task := range queue { + if slices.Contains(req.TaskTypes, task.Type) { + // Remove task from queue + s.pendingQueues[priority] = append(queue[:i], queue[i+1:]...) + + // Update task status + now := time.Now() + + s.tasksMux.Lock() + task.Status = RUNNING + task.StartedAt = &now + task.WorkerID = worker.ID + s.tasksMux.Unlock() + + s.incrementCurrentLoad(worker.ID) + + return &proto.FetchTaskResponse{Task: task.toProtoTask(), HasTask: true}, nil + } + } + } + + return &proto.FetchTaskResponse{HasTask: false}, nil +} + +// Util functions + +// appendTaskToQueue appends a task back to the pending queue based on its priority +func (s *Server) appendTaskToQueue(task *Task) { + s.queuesMux.Lock() + defer s.queuesMux.Unlock() + + s.pendingQueues[task.Priority] = append(s.pendingQueues[task.Priority], task) +} + +// decrementCurrentLoad decreases the current load of the specified worker +func (s *Server) decrementCurrentLoad(workerID string) { + s.workersMux.Lock() + defer s.workersMux.Unlock() + + if worker, exists := s.workers[workerID]; exists { + worker.CurrentLoad-- + } +} + +// incrementCurrentLoad increases the current load of the specified worker +func (s *Server) incrementCurrentLoad(workerID string) { + s.workersMux.Lock() + defer s.workersMux.Unlock() + + if worker, exists := s.workers[workerID]; exists { + worker.CurrentLoad++ + } +} + +// findTask retrieves a task by its ID, returning an error if not found +func (s *Server) findTask(taskID string) (*Task, error) { s.tasksMux.RLock() defer s.tasksMux.RUnlock() - task, exists := s.tasks[req.TaskId] + task, exists := s.tasks[taskID] if !exists { - return nil, status.Errorf(codes.NotFound, "task %s not found", req.TaskId) + return nil, status.Errorf(codes.NotFound, "task %s not found", taskID) } - if task.Status != COMPLETED { - return nil, status.Errorf(codes.FailedPrecondition, "task %s not completed yet", req.TaskId) + return task, nil +} + +// findWorker retrieves a worker by its ID, returning an error if not found +func (s *Server) findWorker(workerID string) (*worker.Worker, error) { + s.workersMux.RLock() + defer s.workersMux.RUnlock() + + worker, exists := s.workers[workerID] + if !exists { + return nil, status.Errorf(codes.NotFound, "worker %s not registered", workerID) } - return &proto.GetTaskResultResponse{Task: task.toProtoTask()}, nil + return worker, nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 6442d2d..5162258 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,10 +1,11 @@ -package internal +package server import ( "context" "testing" "time" + "github.com/mateusmlo/taskqueue/internal/worker" "github.com/mateusmlo/taskqueue/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -587,3 +588,824 @@ func TestServer_ConcurrentGetTaskStatus(t *testing.T) { } } } + +func TestServer_RegisterWorker(t *testing.T) { + s := NewServer() + defer s.cancel() + ctx := context.Background() + + tests := []struct { + name string + worker *proto.Worker + wantErr bool + }{ + { + name: "valid worker registration", + worker: &proto.Worker{ + TaskTypes: []string{"image-processing", "data-export"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + "region": "us-west-1", + }, + }, + wantErr: false, + }, + { + name: "worker with single task type", + worker: &proto.Worker{ + TaskTypes: []string{"email-send"}, + Capacity: 5, + Metadata: map[string]string{ + "address": "localhost:8081", + }, + }, + wantErr: false, + }, + { + name: "worker with high capacity", + worker: &proto.Worker{ + TaskTypes: []string{"batch-job", "report-gen"}, + Capacity: 100, + Metadata: map[string]string{ + "address": "localhost:8082", + "type": "high-capacity", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &proto.RegisterWorkerRequest{ + Worker: tt.worker, + } + + resp, err := s.RegisterWorker(ctx, req) + + if (err != nil) != tt.wantErr { + t.Errorf("RegisterWorker() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + if resp == nil { + t.Fatal("RegisterWorker() returned nil response") + } + + if !resp.Success { + t.Error("RegisterWorker() success = false, want true") + } + + if resp.WorkerId == "" { + t.Error("RegisterWorker() returned empty worker ID") + } + + // Verify worker was stored + s.workersMux.RLock() + worker, exists := s.workers[resp.WorkerId] + s.workersMux.RUnlock() + + if !exists { + t.Error("Worker was not stored in server.workers") + } + + if worker.ID != resp.WorkerId { + t.Errorf("Worker.ID = %v, want %v", worker.ID, resp.WorkerId) + } + + if worker.Capacity != int(tt.worker.Capacity) { + t.Errorf("Worker.Capacity = %v, want %v", worker.Capacity, tt.worker.Capacity) + } + + if worker.CurrentLoad != 0 { + t.Errorf("Worker.CurrentLoad = %v, want 0", worker.CurrentLoad) + } + + if len(worker.TaskTypes) != len(tt.worker.TaskTypes) { + t.Errorf("Worker.TaskTypes length = %v, want %v", len(worker.TaskTypes), len(tt.worker.TaskTypes)) + } + } + }) + } +} + +func TestServer_Heartbeat(t *testing.T) { + s := NewServer() + defer s.cancel() + ctx := context.Background() + + // Register a worker first + registerReq := &proto.RegisterWorkerRequest{ + Worker: &proto.Worker{ + TaskTypes: []string{"test-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + } + registerResp, err := s.RegisterWorker(ctx, registerReq) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + tests := []struct { + name string + workerID string + currentLoad int32 + wantErr bool + wantCode codes.Code + }{ + { + name: "valid heartbeat", + workerID: registerResp.WorkerId, + currentLoad: 5, + wantErr: false, + }, + { + name: "heartbeat with zero load", + workerID: registerResp.WorkerId, + currentLoad: 0, + wantErr: false, + }, + { + name: "heartbeat with max load", + workerID: registerResp.WorkerId, + currentLoad: 10, + wantErr: false, + }, + { + name: "heartbeat from non-existent worker", + workerID: "non-existent-worker", + currentLoad: 3, + wantErr: true, + wantCode: codes.NotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Record time before heartbeat + timeBefore := time.Now() + + req := &proto.HeartbeatRequest{ + WorkerId: tt.workerID, + CurrentLoad: tt.currentLoad, + } + + resp, err := s.Heartbeat(ctx, req) + + if tt.wantErr { + if err == nil { + t.Error("Heartbeat() expected error, got nil") + return + } + + st, ok := status.FromError(err) + if !ok { + t.Error("Error is not a gRPC status error") + return + } + + if st.Code() != tt.wantCode { + t.Errorf("Heartbeat() error code = %v, want %v", st.Code(), tt.wantCode) + } + } else { + if err != nil { + t.Errorf("Heartbeat() unexpected error = %v", err) + return + } + + if resp == nil { + t.Fatal("Heartbeat() returned nil response") + } + + if !resp.Success { + t.Error("Heartbeat() success = false, want true") + } + + if resp.CurrentLoad != tt.currentLoad { + t.Errorf("Heartbeat() current_load = %v, want %v", resp.CurrentLoad, tt.currentLoad) + } + + // Verify worker's last heartbeat was updated + s.workersMux.RLock() + worker := s.workers[tt.workerID] + s.workersMux.RUnlock() + + if worker.LastHeartbeat.Before(timeBefore) { + t.Error("Worker's LastHeartbeat was not updated") + } + + if worker.CurrentLoad != int(tt.currentLoad) { + t.Errorf("Worker.CurrentLoad = %v, want %v", worker.CurrentLoad, tt.currentLoad) + } + } + }) + } +} + +func TestServer_FetchTask(t *testing.T) { + s := NewServer() + defer s.cancel() + ctx := context.Background() + + // Register a worker + registerReq := &proto.RegisterWorkerRequest{ + Worker: &proto.Worker{ + TaskTypes: []string{"image-processing", "data-export"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + } + registerResp, err := s.RegisterWorker(ctx, registerReq) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + // Submit tasks with different priorities + highPriorityTask, _ := s.SubmitTask(ctx, &proto.SubmitTaskRequest{ + Type: "image-processing", + Payload: []byte("high priority"), + Priority: int32(proto.Priority_HIGH), + MaxRetries: 3, + }) + + mediumPriorityTask, _ := s.SubmitTask(ctx, &proto.SubmitTaskRequest{ + Type: "data-export", + Payload: []byte("medium priority"), + Priority: int32(proto.Priority_MEDIUM), + MaxRetries: 3, + }) + + lowPriorityTask, _ := s.SubmitTask(ctx, &proto.SubmitTaskRequest{ + Type: "image-processing", + Payload: []byte("low priority"), + Priority: int32(proto.Priority_LOW), + MaxRetries: 3, + }) + + tests := []struct { + name string + workerID string + taskTypes []string + wantHasTask bool + wantTaskID string + wantErr bool + wantCode codes.Code + setupFunc func() + }{ + { + name: "fetch high priority task", + workerID: registerResp.WorkerId, + taskTypes: []string{"image-processing", "data-export"}, + wantHasTask: true, + wantTaskID: highPriorityTask.TaskId, + wantErr: false, + }, + { + name: "fetch medium priority task after high priority consumed", + workerID: registerResp.WorkerId, + taskTypes: []string{"data-export"}, + wantHasTask: true, + wantTaskID: mediumPriorityTask.TaskId, + wantErr: false, + }, + { + name: "fetch low priority task", + workerID: registerResp.WorkerId, + taskTypes: []string{"image-processing"}, + wantHasTask: true, + wantTaskID: lowPriorityTask.TaskId, + wantErr: false, + }, + { + name: "no tasks available for worker task types", + workerID: registerResp.WorkerId, + taskTypes: []string{"email-send"}, + wantHasTask: false, + wantErr: false, + }, + { + name: "worker not found", + workerID: "non-existent-worker", + taskTypes: []string{"image-processing"}, + wantHasTask: false, + wantErr: true, + wantCode: codes.NotFound, + }, + { + name: "worker at capacity cannot fetch tasks", + workerID: registerResp.WorkerId, + taskTypes: []string{"image-processing"}, + wantHasTask: false, + wantErr: false, + setupFunc: func() { + // Set worker to full capacity + s.workersMux.Lock() + s.workers[registerResp.WorkerId].CurrentLoad = 10 + s.workersMux.Unlock() + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupFunc != nil { + tt.setupFunc() + } + + req := &proto.FetchTaskRequest{ + WorkerId: tt.workerID, + TaskTypes: tt.taskTypes, + } + + resp, err := s.FetchTask(ctx, req) + + if tt.wantErr { + if err == nil { + t.Error("FetchTask() expected error, got nil") + return + } + + st, ok := status.FromError(err) + if !ok { + t.Error("Error is not a gRPC status error") + return + } + + if st.Code() != tt.wantCode { + t.Errorf("FetchTask() error code = %v, want %v", st.Code(), tt.wantCode) + } + } else { + if err != nil { + t.Errorf("FetchTask() unexpected error = %v", err) + return + } + + if resp == nil { + t.Fatal("FetchTask() returned nil response") + } + + if resp.HasTask != tt.wantHasTask { + t.Errorf("FetchTask() has_task = %v, want %v", resp.HasTask, tt.wantHasTask) + } + + if tt.wantHasTask { + if resp.Task == nil { + t.Fatal("FetchTask() returned nil task when has_task is true") + } + + if resp.Task.Id != tt.wantTaskID { + t.Errorf("FetchTask() task ID = %v, want %v", resp.Task.Id, tt.wantTaskID) + } + + // Verify task status was updated to RUNNING + s.tasksMux.RLock() + task := s.tasks[resp.Task.Id] + s.tasksMux.RUnlock() + + if task.Status != RUNNING { + t.Errorf("Task status = %v, want RUNNING", task.Status) + } + + if task.StartedAt == nil { + t.Error("Task.StartedAt is nil") + } + + if task.WorkerID != tt.workerID { + t.Errorf("Task.WorkerID = %v, want %v", task.WorkerID, tt.workerID) + } + } + } + }) + } +} + +func TestServer_SubmitResult(t *testing.T) { + s := NewServer() + defer s.cancel() + ctx := context.Background() + + // Register a worker + registerReq := &proto.RegisterWorkerRequest{ + Worker: &proto.Worker{ + TaskTypes: []string{"test-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + } + registerResp, err := s.RegisterWorker(ctx, registerReq) + if err != nil { + t.Fatalf("Failed to register worker: %v", err) + } + + // Submit and fetch a task for success test + submitReq := &proto.SubmitTaskRequest{ + Type: "test-task", + Payload: []byte("test"), + Priority: int32(proto.Priority_HIGH), + MaxRetries: 3, + } + submitResp, _ := s.SubmitTask(ctx, submitReq) + + fetchReq := &proto.FetchTaskRequest{ + WorkerId: registerResp.WorkerId, + TaskTypes: []string{"test-task"}, + } + s.FetchTask(ctx, fetchReq) + + // Submit another task for retry test + retryTaskResp, _ := s.SubmitTask(ctx, &proto.SubmitTaskRequest{ + Type: "test-task", + Payload: []byte("retry test"), + Priority: int32(proto.Priority_HIGH), + MaxRetries: 3, + }) + s.FetchTask(ctx, fetchReq) + + // Submit task for max retries test + maxRetriesTaskResp, _ := s.SubmitTask(ctx, &proto.SubmitTaskRequest{ + Type: "test-task", + Payload: []byte("max retries test"), + Priority: int32(proto.Priority_HIGH), + MaxRetries: 2, + }) + s.FetchTask(ctx, fetchReq) + + // Manually set retry count to simulate previous failures + s.tasksMux.Lock() + s.tasks[maxRetriesTaskResp.TaskId].RetryCount = 1 + s.tasksMux.Unlock() + + tests := []struct { + name string + taskID string + result []byte + error string + wantErr bool + wantCode codes.Code + verify func(t *testing.T, taskID string) + }{ + { + name: "successful task completion", + taskID: submitResp.TaskId, + result: []byte("success result"), + error: "", + wantErr: false, + verify: func(t *testing.T, taskID string) { + s.tasksMux.RLock() + task := s.tasks[taskID] + s.tasksMux.RUnlock() + + if task.Status != COMPLETED { + t.Errorf("Task status = %v, want COMPLETED", task.Status) + } + + if string(task.Result) != "success result" { + t.Errorf("Task result = %v, want 'success result'", string(task.Result)) + } + + if task.CompletedAt == nil { + t.Error("Task.CompletedAt is nil") + } + + if task.Error != "" { + t.Errorf("Task.Error = %v, want empty string", task.Error) + } + }, + }, + { + name: "task failure with retry available", + taskID: retryTaskResp.TaskId, + result: nil, + error: "processing failed", + wantErr: false, + verify: func(t *testing.T, taskID string) { + s.tasksMux.RLock() + task := s.tasks[taskID] + s.tasksMux.RUnlock() + + if task.Status != PENDING { + t.Errorf("Task status = %v, want PENDING (for retry)", task.Status) + } + + if task.RetryCount != 1 { + t.Errorf("Task retry count = %v, want 1", task.RetryCount) + } + + if task.Error != "processing failed" { + t.Errorf("Task error = %v, want 'processing failed'", task.Error) + } + + if task.StartedAt != nil { + t.Error("Task.StartedAt should be nil after retry reset") + } + + if task.CompletedAt != nil { + t.Error("Task.CompletedAt should be nil after retry reset") + } + + // Verify task was re-added to queue + s.queuesMux.RLock() + queue := s.pendingQueues[task.Priority] + s.queuesMux.RUnlock() + + found := false + for _, qTask := range queue { + if qTask.ID == taskID { + found = true + break + } + } + + if !found { + t.Error("Task was not re-added to pending queue for retry") + } + }, + }, + { + name: "task failure exceeding max retries", + taskID: maxRetriesTaskResp.TaskId, + result: nil, + error: "fatal error", + wantErr: true, + wantCode: codes.DeadlineExceeded, + verify: func(t *testing.T, taskID string) { + s.tasksMux.RLock() + task := s.tasks[taskID] + s.tasksMux.RUnlock() + + if task.Status != FAILED { + t.Errorf("Task status = %v, want FAILED", task.Status) + } + + if task.RetryCount != 2 { + t.Errorf("Task retry count = %v, want 2", task.RetryCount) + } + }, + }, + { + name: "submit result for non-existent task", + taskID: "non-existent-task", + result: []byte("result"), + error: "", + wantErr: true, + wantCode: codes.NotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &proto.SubmitResultRequest{ + TaskId: tt.taskID, + Result: tt.result, + Error: tt.error, + } + + resp, err := s.SubmitResult(ctx, req) + + if tt.wantErr { + if err == nil { + t.Error("SubmitResult() expected error, got nil") + return + } + + st, ok := status.FromError(err) + if !ok { + t.Error("Error is not a gRPC status error") + return + } + + if st.Code() != tt.wantCode { + t.Errorf("SubmitResult() error code = %v, want %v", st.Code(), tt.wantCode) + } + } else { + if err != nil { + t.Errorf("SubmitResult() unexpected error = %v", err) + return + } + + if resp == nil { + t.Fatal("SubmitResult() returned nil response") + } + + if !resp.Success { + t.Error("SubmitResult() success = false, want true") + } + } + + if tt.verify != nil { + tt.verify(t, tt.taskID) + } + }) + } +} + +func TestServer_UtilityFunctions(t *testing.T) { + t.Run("appendTaskToQueue", func(t *testing.T) { + s := NewServer() + defer s.cancel() + + highTask := &Task{ID: "high-1", Priority: HIGH, Status: PENDING} + mediumTask := &Task{ID: "medium-1", Priority: MEDIUM, Status: PENDING} + lowTask := &Task{ID: "low-1", Priority: LOW, Status: PENDING} + + s.appendTaskToQueue(highTask) + s.appendTaskToQueue(mediumTask) + s.appendTaskToQueue(lowTask) + + s.queuesMux.RLock() + defer s.queuesMux.RUnlock() + + if len(s.pendingQueues[HIGH]) != 1 { + t.Errorf("HIGH queue length = %v, want 1", len(s.pendingQueues[HIGH])) + } + + if len(s.pendingQueues[MEDIUM]) != 1 { + t.Errorf("MEDIUM queue length = %v, want 1", len(s.pendingQueues[MEDIUM])) + } + + if len(s.pendingQueues[LOW]) != 1 { + t.Errorf("LOW queue length = %v, want 1", len(s.pendingQueues[LOW])) + } + + if s.pendingQueues[HIGH][0].ID != "high-1" { + t.Errorf("HIGH queue task ID = %v, want 'high-1'", s.pendingQueues[HIGH][0].ID) + } + }) + + t.Run("incrementCurrentLoad", func(t *testing.T) { + s := NewServer() + defer s.cancel() + + // Create a worker + workerID := "test-worker-1" + s.workers[workerID] = &worker.Worker{ + ID: workerID, + Capacity: 10, + CurrentLoad: 5, + } + + s.incrementCurrentLoad(workerID) + + s.workersMux.RLock() + load := s.workers[workerID].CurrentLoad + s.workersMux.RUnlock() + + if load != 6 { + t.Errorf("Worker current load = %v, want 6", load) + } + + // Test with non-existent worker (should not panic) + s.incrementCurrentLoad("non-existent") + }) + + t.Run("decrementCurrentLoad", func(t *testing.T) { + s := NewServer() + defer s.cancel() + + // Create a worker + workerID := "test-worker-2" + s.workers[workerID] = &worker.Worker{ + ID: workerID, + Capacity: 10, + CurrentLoad: 5, + } + + s.decrementCurrentLoad(workerID) + + s.workersMux.RLock() + load := s.workers[workerID].CurrentLoad + s.workersMux.RUnlock() + + if load != 4 { + t.Errorf("Worker current load = %v, want 4", load) + } + + // Test with non-existent worker (should not panic) + s.decrementCurrentLoad("non-existent") + }) + + t.Run("findTask", func(t *testing.T) { + s := NewServer() + defer s.cancel() + + testTask := &Task{ID: "test-task-1", Type: "test"} + s.tasks["test-task-1"] = testTask + + // Test finding existing task + found, err := s.findTask("test-task-1") + if err != nil { + t.Errorf("findTask() unexpected error = %v", err) + } + if found.ID != "test-task-1" { + t.Errorf("findTask() task ID = %v, want 'test-task-1'", found.ID) + } + + // Test finding non-existent task + _, err = s.findTask("non-existent") + if err == nil { + t.Error("findTask() expected error for non-existent task, got nil") + } + + st, ok := status.FromError(err) + if !ok { + t.Error("Error is not a gRPC status error") + } + if st.Code() != codes.NotFound { + t.Errorf("findTask() error code = %v, want NotFound", st.Code()) + } + }) + + t.Run("findWorker", func(t *testing.T) { + s := NewServer() + defer s.cancel() + + testWorker := &worker.Worker{ID: "test-worker-1"} + s.workers["test-worker-1"] = testWorker + + // Test finding existing worker + found, err := s.findWorker("test-worker-1") + if err != nil { + t.Errorf("findWorker() unexpected error = %v", err) + } + if found.ID != "test-worker-1" { + t.Errorf("findWorker() worker ID = %v, want 'test-worker-1'", found.ID) + } + + // Test finding non-existent worker + _, err = s.findWorker("non-existent") + if err == nil { + t.Error("findWorker() expected error for non-existent worker, got nil") + } + + st, ok := status.FromError(err) + if !ok { + t.Error("Error is not a gRPC status error") + } + if st.Code() != codes.NotFound { + t.Errorf("findWorker() error code = %v, want NotFound", st.Code()) + } + }) +} + +func TestServer_ConcurrentWorkerOperations(t *testing.T) { + s := NewServer() + defer s.cancel() + ctx := context.Background() + + // Test concurrent worker registrations + const numWorkers = 50 + results := make(chan *proto.RegisterWorkerResponse, numWorkers) + errors := make(chan error, numWorkers) + + for i := range numWorkers { + go func(idx int) { + req := &proto.RegisterWorkerRequest{ + Worker: &proto.Worker{ + TaskTypes: []string{"test-task"}, + Capacity: 10, + Metadata: map[string]string{ + "address": "localhost:8080", + }, + }, + } + resp, err := s.RegisterWorker(ctx, req) + if err != nil { + errors <- err + } else { + results <- resp + } + }(i) + } + + // Collect results + workerIDs := make(map[string]bool) + for range numWorkers { + select { + case resp := <-results: + if workerIDs[resp.WorkerId] { + t.Errorf("Duplicate worker ID generated: %s", resp.WorkerId) + } + workerIDs[resp.WorkerId] = true + case err := <-errors: + t.Errorf("Concurrent RegisterWorker failed: %v", err) + } + } + + // Verify all workers were stored + s.workersMux.RLock() + storedCount := len(s.workers) + s.workersMux.RUnlock() + + if storedCount != numWorkers { + t.Errorf("Expected %d workers stored, got %d", numWorkers, storedCount) + } +} diff --git a/internal/worker/worker.go b/internal/worker/worker.go new file mode 100644 index 0000000..efd7667 --- /dev/null +++ b/internal/worker/worker.go @@ -0,0 +1,40 @@ +package worker + +import ( + "time" + + "github.com/google/uuid" + "github.com/mateusmlo/taskqueue/proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type Worker struct { + ID string + Address string + RegisteredAt time.Time + LastHeartbeat time.Time + TaskTypes []string + Capacity int + CurrentLoad int + Metadata map[string]string +} + +// FromProtoWorker initializes a Worker instance from a proto.Worker message (server generates ID) +func (w *Worker) FromProtoWorker(pw *proto.Worker) error { + uuid, err := uuid.NewV7() + if err != nil { + return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err) + } + + w.ID = uuid.String() + w.TaskTypes = pw.TaskTypes + w.Address = pw.Metadata["address"] + w.Capacity = int(pw.Capacity) + w.CurrentLoad = 0 + w.Metadata = pw.Metadata + w.RegisteredAt = time.Now() + w.LastHeartbeat = time.Now() + + return nil +} diff --git a/proto/taskqueue.pb.go b/proto/taskqueue.pb.go index db48ebe..0b75982 100644 --- a/proto/taskqueue.pb.go +++ b/proto/taskqueue.pb.go @@ -544,30 +544,32 @@ func (x *GetTaskResultResponse) GetTask() *Task { // ============================================ // WORKER RPCS // ============================================ -type RegisterWorkerRequest struct { +type Worker struct { state protoimpl.MessageState `protogen:"open.v1"` WorkerId string `protobuf:"bytes,1,opt,name=worker_id,json=workerId,proto3" json:"worker_id,omitempty"` TaskTypes []string `protobuf:"bytes,2,rep,name=task_types,json=taskTypes,proto3" json:"task_types,omitempty"` Capacity int32 `protobuf:"varint,3,opt,name=capacity,proto3" json:"capacity,omitempty"` - Metadata map[string]string `protobuf:"bytes,4,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + CurrentLoad int32 `protobuf:"varint,4,opt,name=current_load,json=currentLoad,proto3" json:"current_load,omitempty"` + LastHeartbeat int32 `protobuf:"varint,5,opt,name=last_heartbeat,json=lastHeartbeat,proto3" json:"last_heartbeat,omitempty"` + Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *RegisterWorkerRequest) Reset() { - *x = RegisterWorkerRequest{} +func (x *Worker) Reset() { + *x = Worker{} mi := &file_proto_taskqueue_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *RegisterWorkerRequest) String() string { +func (x *Worker) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RegisterWorkerRequest) ProtoMessage() {} +func (*Worker) ProtoMessage() {} -func (x *RegisterWorkerRequest) ProtoReflect() protoreflect.Message { +func (x *Worker) ProtoReflect() protoreflect.Message { mi := &file_proto_taskqueue_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -579,50 +581,108 @@ func (x *RegisterWorkerRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RegisterWorkerRequest.ProtoReflect.Descriptor instead. -func (*RegisterWorkerRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use Worker.ProtoReflect.Descriptor instead. +func (*Worker) Descriptor() ([]byte, []int) { return file_proto_taskqueue_proto_rawDescGZIP(), []int{7} } -func (x *RegisterWorkerRequest) GetWorkerId() string { +func (x *Worker) GetWorkerId() string { if x != nil { return x.WorkerId } return "" } -func (x *RegisterWorkerRequest) GetTaskTypes() []string { +func (x *Worker) GetTaskTypes() []string { if x != nil { return x.TaskTypes } return nil } -func (x *RegisterWorkerRequest) GetCapacity() int32 { +func (x *Worker) GetCapacity() int32 { if x != nil { return x.Capacity } return 0 } -func (x *RegisterWorkerRequest) GetMetadata() map[string]string { +func (x *Worker) GetCurrentLoad() int32 { + if x != nil { + return x.CurrentLoad + } + return 0 +} + +func (x *Worker) GetLastHeartbeat() int32 { + if x != nil { + return x.LastHeartbeat + } + return 0 +} + +func (x *Worker) GetMetadata() map[string]string { if x != nil { return x.Metadata } return nil } +type RegisterWorkerRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Worker *Worker `protobuf:"bytes,1,opt,name=worker,proto3" json:"worker,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterWorkerRequest) Reset() { + *x = RegisterWorkerRequest{} + mi := &file_proto_taskqueue_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterWorkerRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterWorkerRequest) ProtoMessage() {} + +func (x *RegisterWorkerRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_taskqueue_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterWorkerRequest.ProtoReflect.Descriptor instead. +func (*RegisterWorkerRequest) Descriptor() ([]byte, []int) { + return file_proto_taskqueue_proto_rawDescGZIP(), []int{8} +} + +func (x *RegisterWorkerRequest) GetWorker() *Worker { + if x != nil { + return x.Worker + } + return nil +} + type RegisterWorkerResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` - Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + WorkerId string `protobuf:"bytes,2,opt,name=worker_id,json=workerId,proto3" json:"worker_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *RegisterWorkerResponse) Reset() { *x = RegisterWorkerResponse{} - mi := &file_proto_taskqueue_proto_msgTypes[8] + mi := &file_proto_taskqueue_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -634,7 +694,7 @@ func (x *RegisterWorkerResponse) String() string { func (*RegisterWorkerResponse) ProtoMessage() {} func (x *RegisterWorkerResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[8] + mi := &file_proto_taskqueue_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -647,7 +707,7 @@ func (x *RegisterWorkerResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RegisterWorkerResponse.ProtoReflect.Descriptor instead. func (*RegisterWorkerResponse) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{8} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{9} } func (x *RegisterWorkerResponse) GetSuccess() bool { @@ -657,9 +717,9 @@ func (x *RegisterWorkerResponse) GetSuccess() bool { return false } -func (x *RegisterWorkerResponse) GetMessage() string { +func (x *RegisterWorkerResponse) GetWorkerId() string { if x != nil { - return x.Message + return x.WorkerId } return "" } @@ -674,7 +734,7 @@ type HeartbeatRequest struct { func (x *HeartbeatRequest) Reset() { *x = HeartbeatRequest{} - mi := &file_proto_taskqueue_proto_msgTypes[9] + mi := &file_proto_taskqueue_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -686,7 +746,7 @@ func (x *HeartbeatRequest) String() string { func (*HeartbeatRequest) ProtoMessage() {} func (x *HeartbeatRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[9] + mi := &file_proto_taskqueue_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -699,7 +759,7 @@ func (x *HeartbeatRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use HeartbeatRequest.ProtoReflect.Descriptor instead. func (*HeartbeatRequest) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{9} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{10} } func (x *HeartbeatRequest) GetWorkerId() string { @@ -719,13 +779,14 @@ func (x *HeartbeatRequest) GetCurrentLoad() int32 { type HeartbeatResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + CurrentLoad int32 `protobuf:"varint,2,opt,name=current_load,json=currentLoad,proto3" json:"current_load,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *HeartbeatResponse) Reset() { *x = HeartbeatResponse{} - mi := &file_proto_taskqueue_proto_msgTypes[10] + mi := &file_proto_taskqueue_proto_msgTypes[11] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -737,7 +798,7 @@ func (x *HeartbeatResponse) String() string { func (*HeartbeatResponse) ProtoMessage() {} func (x *HeartbeatResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[10] + mi := &file_proto_taskqueue_proto_msgTypes[11] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -750,7 +811,7 @@ func (x *HeartbeatResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use HeartbeatResponse.ProtoReflect.Descriptor instead. func (*HeartbeatResponse) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{10} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{11} } func (x *HeartbeatResponse) GetSuccess() bool { @@ -760,6 +821,13 @@ func (x *HeartbeatResponse) GetSuccess() bool { return false } +func (x *HeartbeatResponse) GetCurrentLoad() int32 { + if x != nil { + return x.CurrentLoad + } + return 0 +} + type FetchTaskRequest struct { state protoimpl.MessageState `protogen:"open.v1"` WorkerId string `protobuf:"bytes,1,opt,name=worker_id,json=workerId,proto3" json:"worker_id,omitempty"` @@ -770,7 +838,7 @@ type FetchTaskRequest struct { func (x *FetchTaskRequest) Reset() { *x = FetchTaskRequest{} - mi := &file_proto_taskqueue_proto_msgTypes[11] + mi := &file_proto_taskqueue_proto_msgTypes[12] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -782,7 +850,7 @@ func (x *FetchTaskRequest) String() string { func (*FetchTaskRequest) ProtoMessage() {} func (x *FetchTaskRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[11] + mi := &file_proto_taskqueue_proto_msgTypes[12] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -795,7 +863,7 @@ func (x *FetchTaskRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use FetchTaskRequest.ProtoReflect.Descriptor instead. func (*FetchTaskRequest) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{11} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{12} } func (x *FetchTaskRequest) GetWorkerId() string { @@ -822,7 +890,7 @@ type FetchTaskResponse struct { func (x *FetchTaskResponse) Reset() { *x = FetchTaskResponse{} - mi := &file_proto_taskqueue_proto_msgTypes[12] + mi := &file_proto_taskqueue_proto_msgTypes[13] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -834,7 +902,7 @@ func (x *FetchTaskResponse) String() string { func (*FetchTaskResponse) ProtoMessage() {} func (x *FetchTaskResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[12] + mi := &file_proto_taskqueue_proto_msgTypes[13] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -847,7 +915,7 @@ func (x *FetchTaskResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use FetchTaskResponse.ProtoReflect.Descriptor instead. func (*FetchTaskResponse) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{12} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{13} } func (x *FetchTaskResponse) GetTask() *Task { @@ -868,14 +936,14 @@ type SubmitResultRequest struct { state protoimpl.MessageState `protogen:"open.v1"` TaskId string `protobuf:"bytes,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` Result []byte `protobuf:"bytes,2,opt,name=result,proto3" json:"result,omitempty"` - Error *string `protobuf:"bytes,3,opt,name=error,proto3,oneof" json:"error,omitempty"` + Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *SubmitResultRequest) Reset() { *x = SubmitResultRequest{} - mi := &file_proto_taskqueue_proto_msgTypes[13] + mi := &file_proto_taskqueue_proto_msgTypes[14] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -887,7 +955,7 @@ func (x *SubmitResultRequest) String() string { func (*SubmitResultRequest) ProtoMessage() {} func (x *SubmitResultRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[13] + mi := &file_proto_taskqueue_proto_msgTypes[14] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -900,7 +968,7 @@ func (x *SubmitResultRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SubmitResultRequest.ProtoReflect.Descriptor instead. func (*SubmitResultRequest) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{13} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{14} } func (x *SubmitResultRequest) GetTaskId() string { @@ -918,8 +986,8 @@ func (x *SubmitResultRequest) GetResult() []byte { } func (x *SubmitResultRequest) GetError() string { - if x != nil && x.Error != nil { - return *x.Error + if x != nil { + return x.Error } return "" } @@ -927,13 +995,15 @@ func (x *SubmitResultRequest) GetError() string { type SubmitResultResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + Result []byte `protobuf:"bytes,2,opt,name=result,proto3" json:"result,omitempty"` + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *SubmitResultResponse) Reset() { *x = SubmitResultResponse{} - mi := &file_proto_taskqueue_proto_msgTypes[14] + mi := &file_proto_taskqueue_proto_msgTypes[15] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -945,7 +1015,7 @@ func (x *SubmitResultResponse) String() string { func (*SubmitResultResponse) ProtoMessage() {} func (x *SubmitResultResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_taskqueue_proto_msgTypes[14] + mi := &file_proto_taskqueue_proto_msgTypes[15] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -958,7 +1028,7 @@ func (x *SubmitResultResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SubmitResultResponse.ProtoReflect.Descriptor instead. func (*SubmitResultResponse) Descriptor() ([]byte, []int) { - return file_proto_taskqueue_proto_rawDescGZIP(), []int{14} + return file_proto_taskqueue_proto_rawDescGZIP(), []int{15} } func (x *SubmitResultResponse) GetSuccess() bool { @@ -968,6 +1038,20 @@ func (x *SubmitResultResponse) GetSuccess() bool { return false } +func (x *SubmitResultResponse) GetResult() []byte { + if x != nil { + return x.Result + } + return nil +} + +func (x *SubmitResultResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + var File_proto_taskqueue_proto protoreflect.FileDescriptor const file_proto_taskqueue_proto_rawDesc = "" + @@ -1007,38 +1091,44 @@ const file_proto_taskqueue_proto_rawDesc = "" + "\x14GetTaskResultRequest\x12\x17\n" + "\atask_id\x18\x01 \x01(\tR\x06taskId\"<\n" + "\x15GetTaskResultResponse\x12#\n" + - "\x04task\x18\x01 \x01(\v2\x0f.taskqueue.TaskR\x04task\"\xf8\x01\n" + - "\x15RegisterWorkerRequest\x12\x1b\n" + + "\x04task\x18\x01 \x01(\v2\x0f.taskqueue.TaskR\x04task\"\xa4\x02\n" + + "\x06Worker\x12\x1b\n" + "\tworker_id\x18\x01 \x01(\tR\bworkerId\x12\x1d\n" + "\n" + "task_types\x18\x02 \x03(\tR\ttaskTypes\x12\x1a\n" + - "\bcapacity\x18\x03 \x01(\x05R\bcapacity\x12J\n" + - "\bmetadata\x18\x04 \x03(\v2..taskqueue.RegisterWorkerRequest.MetadataEntryR\bmetadata\x1a;\n" + + "\bcapacity\x18\x03 \x01(\x05R\bcapacity\x12!\n" + + "\fcurrent_load\x18\x04 \x01(\x05R\vcurrentLoad\x12%\n" + + "\x0elast_heartbeat\x18\x05 \x01(\x05R\rlastHeartbeat\x12;\n" + + "\bmetadata\x18\x06 \x03(\v2\x1f.taskqueue.Worker.MetadataEntryR\bmetadata\x1a;\n" + "\rMetadataEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"L\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"B\n" + + "\x15RegisterWorkerRequest\x12)\n" + + "\x06worker\x18\x01 \x01(\v2\x11.taskqueue.WorkerR\x06worker\"O\n" + "\x16RegisterWorkerResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x18\n" + - "\amessage\x18\x02 \x01(\tR\amessage\"R\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1b\n" + + "\tworker_id\x18\x02 \x01(\tR\bworkerId\"R\n" + "\x10HeartbeatRequest\x12\x1b\n" + "\tworker_id\x18\x01 \x01(\tR\bworkerId\x12!\n" + - "\fcurrent_load\x18\x02 \x01(\x05R\vcurrentLoad\"-\n" + + "\fcurrent_load\x18\x02 \x01(\x05R\vcurrentLoad\"P\n" + "\x11HeartbeatResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess\"N\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12!\n" + + "\fcurrent_load\x18\x02 \x01(\x05R\vcurrentLoad\"N\n" + "\x10FetchTaskRequest\x12\x1b\n" + "\tworker_id\x18\x01 \x01(\tR\bworkerId\x12\x1d\n" + "\n" + "task_types\x18\x02 \x03(\tR\ttaskTypes\"S\n" + "\x11FetchTaskResponse\x12#\n" + "\x04task\x18\x01 \x01(\v2\x0f.taskqueue.TaskR\x04task\x12\x19\n" + - "\bhas_task\x18\x02 \x01(\bR\ahasTask\"k\n" + + "\bhas_task\x18\x02 \x01(\bR\ahasTask\"\\\n" + "\x13SubmitResultRequest\x12\x17\n" + "\atask_id\x18\x01 \x01(\tR\x06taskId\x12\x16\n" + - "\x06result\x18\x02 \x01(\fR\x06result\x12\x19\n" + - "\x05error\x18\x03 \x01(\tH\x00R\x05error\x88\x01\x01B\b\n" + - "\x06_error\"0\n" + + "\x06result\x18\x02 \x01(\fR\x06result\x12\x14\n" + + "\x05error\x18\x03 \x01(\tR\x05error\"^\n" + "\x14SubmitResultResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess*A\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x16\n" + + "\x06result\x18\x02 \x01(\fR\x06result\x12\x14\n" + + "\x05error\x18\x04 \x01(\tR\x05error*A\n" + "\n" + "TaskStatus\x12\v\n" + "\aPENDING\x10\x00\x12\v\n" + @@ -1075,7 +1165,7 @@ func file_proto_taskqueue_proto_rawDescGZIP() []byte { } var file_proto_taskqueue_proto_enumTypes = make([]protoimpl.EnumInfo, 2) -var file_proto_taskqueue_proto_msgTypes = make([]protoimpl.MessageInfo, 16) +var file_proto_taskqueue_proto_msgTypes = make([]protoimpl.MessageInfo, 17) var file_proto_taskqueue_proto_goTypes = []any{ (TaskStatus)(0), // 0: taskqueue.TaskStatus (Priority)(0), // 1: taskqueue.Priority @@ -1086,46 +1176,48 @@ var file_proto_taskqueue_proto_goTypes = []any{ (*GetTaskStatusResponse)(nil), // 6: taskqueue.GetTaskStatusResponse (*GetTaskResultRequest)(nil), // 7: taskqueue.GetTaskResultRequest (*GetTaskResultResponse)(nil), // 8: taskqueue.GetTaskResultResponse - (*RegisterWorkerRequest)(nil), // 9: taskqueue.RegisterWorkerRequest - (*RegisterWorkerResponse)(nil), // 10: taskqueue.RegisterWorkerResponse - (*HeartbeatRequest)(nil), // 11: taskqueue.HeartbeatRequest - (*HeartbeatResponse)(nil), // 12: taskqueue.HeartbeatResponse - (*FetchTaskRequest)(nil), // 13: taskqueue.FetchTaskRequest - (*FetchTaskResponse)(nil), // 14: taskqueue.FetchTaskResponse - (*SubmitResultRequest)(nil), // 15: taskqueue.SubmitResultRequest - (*SubmitResultResponse)(nil), // 16: taskqueue.SubmitResultResponse - nil, // 17: taskqueue.RegisterWorkerRequest.MetadataEntry - (*timestamppb.Timestamp)(nil), // 18: google.protobuf.Timestamp + (*Worker)(nil), // 9: taskqueue.Worker + (*RegisterWorkerRequest)(nil), // 10: taskqueue.RegisterWorkerRequest + (*RegisterWorkerResponse)(nil), // 11: taskqueue.RegisterWorkerResponse + (*HeartbeatRequest)(nil), // 12: taskqueue.HeartbeatRequest + (*HeartbeatResponse)(nil), // 13: taskqueue.HeartbeatResponse + (*FetchTaskRequest)(nil), // 14: taskqueue.FetchTaskRequest + (*FetchTaskResponse)(nil), // 15: taskqueue.FetchTaskResponse + (*SubmitResultRequest)(nil), // 16: taskqueue.SubmitResultRequest + (*SubmitResultResponse)(nil), // 17: taskqueue.SubmitResultResponse + nil, // 18: taskqueue.Worker.MetadataEntry + (*timestamppb.Timestamp)(nil), // 19: google.protobuf.Timestamp } var file_proto_taskqueue_proto_depIdxs = []int32{ 1, // 0: taskqueue.Task.priority:type_name -> taskqueue.Priority - 18, // 1: taskqueue.Task.created_at:type_name -> google.protobuf.Timestamp - 18, // 2: taskqueue.Task.started_at:type_name -> google.protobuf.Timestamp - 18, // 3: taskqueue.Task.completed_at:type_name -> google.protobuf.Timestamp + 19, // 1: taskqueue.Task.created_at:type_name -> google.protobuf.Timestamp + 19, // 2: taskqueue.Task.started_at:type_name -> google.protobuf.Timestamp + 19, // 3: taskqueue.Task.completed_at:type_name -> google.protobuf.Timestamp 0, // 4: taskqueue.Task.status:type_name -> taskqueue.TaskStatus 0, // 5: taskqueue.GetTaskStatusResponse.status:type_name -> taskqueue.TaskStatus 2, // 6: taskqueue.GetTaskResultResponse.task:type_name -> taskqueue.Task - 17, // 7: taskqueue.RegisterWorkerRequest.metadata:type_name -> taskqueue.RegisterWorkerRequest.MetadataEntry - 2, // 8: taskqueue.FetchTaskResponse.task:type_name -> taskqueue.Task - 3, // 9: taskqueue.TaskQueue.SubmitTask:input_type -> taskqueue.SubmitTaskRequest - 7, // 10: taskqueue.TaskQueue.GetTaskResult:input_type -> taskqueue.GetTaskResultRequest - 5, // 11: taskqueue.TaskQueue.GetTaskStatus:input_type -> taskqueue.GetTaskStatusRequest - 9, // 12: taskqueue.WorkerService.RegisterWorker:input_type -> taskqueue.RegisterWorkerRequest - 13, // 13: taskqueue.WorkerService.FetchTask:input_type -> taskqueue.FetchTaskRequest - 15, // 14: taskqueue.WorkerService.SubmitResult:input_type -> taskqueue.SubmitResultRequest - 11, // 15: taskqueue.WorkerService.Heartbeat:input_type -> taskqueue.HeartbeatRequest - 4, // 16: taskqueue.TaskQueue.SubmitTask:output_type -> taskqueue.SubmitTaskResponse - 8, // 17: taskqueue.TaskQueue.GetTaskResult:output_type -> taskqueue.GetTaskResultResponse - 6, // 18: taskqueue.TaskQueue.GetTaskStatus:output_type -> taskqueue.GetTaskStatusResponse - 10, // 19: taskqueue.WorkerService.RegisterWorker:output_type -> taskqueue.RegisterWorkerResponse - 14, // 20: taskqueue.WorkerService.FetchTask:output_type -> taskqueue.FetchTaskResponse - 16, // 21: taskqueue.WorkerService.SubmitResult:output_type -> taskqueue.SubmitResultResponse - 12, // 22: taskqueue.WorkerService.Heartbeat:output_type -> taskqueue.HeartbeatResponse - 16, // [16:23] is the sub-list for method output_type - 9, // [9:16] is the sub-list for method input_type - 9, // [9:9] is the sub-list for extension type_name - 9, // [9:9] is the sub-list for extension extendee - 0, // [0:9] is the sub-list for field type_name + 18, // 7: taskqueue.Worker.metadata:type_name -> taskqueue.Worker.MetadataEntry + 9, // 8: taskqueue.RegisterWorkerRequest.worker:type_name -> taskqueue.Worker + 2, // 9: taskqueue.FetchTaskResponse.task:type_name -> taskqueue.Task + 3, // 10: taskqueue.TaskQueue.SubmitTask:input_type -> taskqueue.SubmitTaskRequest + 7, // 11: taskqueue.TaskQueue.GetTaskResult:input_type -> taskqueue.GetTaskResultRequest + 5, // 12: taskqueue.TaskQueue.GetTaskStatus:input_type -> taskqueue.GetTaskStatusRequest + 10, // 13: taskqueue.WorkerService.RegisterWorker:input_type -> taskqueue.RegisterWorkerRequest + 14, // 14: taskqueue.WorkerService.FetchTask:input_type -> taskqueue.FetchTaskRequest + 16, // 15: taskqueue.WorkerService.SubmitResult:input_type -> taskqueue.SubmitResultRequest + 12, // 16: taskqueue.WorkerService.Heartbeat:input_type -> taskqueue.HeartbeatRequest + 4, // 17: taskqueue.TaskQueue.SubmitTask:output_type -> taskqueue.SubmitTaskResponse + 8, // 18: taskqueue.TaskQueue.GetTaskResult:output_type -> taskqueue.GetTaskResultResponse + 6, // 19: taskqueue.TaskQueue.GetTaskStatus:output_type -> taskqueue.GetTaskStatusResponse + 11, // 20: taskqueue.WorkerService.RegisterWorker:output_type -> taskqueue.RegisterWorkerResponse + 15, // 21: taskqueue.WorkerService.FetchTask:output_type -> taskqueue.FetchTaskResponse + 17, // 22: taskqueue.WorkerService.SubmitResult:output_type -> taskqueue.SubmitResultResponse + 13, // 23: taskqueue.WorkerService.Heartbeat:output_type -> taskqueue.HeartbeatResponse + 17, // [17:24] is the sub-list for method output_type + 10, // [10:17] is the sub-list for method input_type + 10, // [10:10] is the sub-list for extension type_name + 10, // [10:10] is the sub-list for extension extendee + 0, // [0:10] is the sub-list for field type_name } func init() { file_proto_taskqueue_proto_init() } @@ -1134,14 +1226,13 @@ func file_proto_taskqueue_proto_init() { return } file_proto_taskqueue_proto_msgTypes[0].OneofWrappers = []any{} - file_proto_taskqueue_proto_msgTypes[13].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_taskqueue_proto_rawDesc), len(file_proto_taskqueue_proto_rawDesc)), NumEnums: 2, - NumMessages: 16, + NumMessages: 17, NumExtensions: 0, NumServices: 2, }, diff --git a/proto/taskqueue.proto b/proto/taskqueue.proto index b3e25e6..bc6a4ba 100644 --- a/proto/taskqueue.proto +++ b/proto/taskqueue.proto @@ -71,16 +71,22 @@ message GetTaskResultResponse { // ============================================ // WORKER RPCS // ============================================ +message Worker { + string worker_id = 1; + repeated string task_types = 2; + int32 capacity = 3; + int32 current_load = 4; + int32 last_heartbeat = 5; + map metadata = 6; +} + message RegisterWorkerRequest { - string worker_id = 1; - repeated string task_types = 2; - int32 capacity = 3; - map metadata = 4; + Worker worker = 1; } message RegisterWorkerResponse { bool success = 1; - string message = 2; + string worker_id = 2; } message HeartbeatRequest { @@ -90,6 +96,7 @@ message HeartbeatRequest { message HeartbeatResponse { bool success = 1; + int32 current_load = 2; } message FetchTaskRequest { @@ -105,11 +112,13 @@ message FetchTaskResponse { message SubmitResultRequest { string task_id = 1; bytes result = 2; - optional string error = 3; + string error = 3; } message SubmitResultResponse { bool success = 1; + bytes result = 2; + string error = 4; } // ============================================