Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,6 @@ build/
tmp/
temp/
*.tmp

# local test certificates
cert/
12 changes: 5 additions & 7 deletions internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ import (
"time"

"github.com/google/uuid"
"github.com/mateusmlo/taskqueue/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/mateusmlo/taskqueue/internal/worker"
"github.com/mateusmlo/taskqueue/proto"
)

type Priority int
Expand All @@ -39,7 +37,7 @@ type Server struct {
pendingQueues map[Priority][]*Task
queuesMux sync.RWMutex

workers map[string]*worker.Worker
workers map[string]*WorkerInfo
workersMux sync.RWMutex

ctx context.Context
Expand Down Expand Up @@ -73,7 +71,7 @@ func NewServer() *Server {
return &Server{
tasks: make(map[string]*Task),
pendingQueues: make(map[Priority][]*Task),
workers: make(map[string]*worker.Worker),
workers: make(map[string]*WorkerInfo),
ctx: ctx,
cancel: cancel,
}
Expand Down Expand Up @@ -158,7 +156,7 @@ func (s *Server) GetTaskResult(ctx context.Context, req *proto.GetTaskResultRequ

// RegisterWorker handles worker registration requests
func (s *Server) RegisterWorker(ctx context.Context, req *proto.RegisterWorkerRequest) (*proto.RegisterWorkerResponse, error) {
var newWorker worker.Worker
var newWorker WorkerInfo
newWorker.FromProtoWorker(req.Worker)

s.workersMux.Lock()
Expand Down Expand Up @@ -300,7 +298,7 @@ func (s *Server) findTask(taskID string) (*Task, error) {
}

// findWorker retrieves a worker by its ID, returning an error if not found
func (s *Server) findWorker(workerID string) (*worker.Worker, error) {
func (s *Server) findWorker(workerID string) (*WorkerInfo, error) {
s.workersMux.RLock()
defer s.workersMux.RUnlock()

Expand Down
7 changes: 3 additions & 4 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"testing"
"time"

"github.com/mateusmlo/taskqueue/internal/worker"
"github.com/mateusmlo/taskqueue/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -1247,7 +1246,7 @@ func TestServer_UtilityFunctions(t *testing.T) {

// Create a worker
workerID := "test-worker-1"
s.workers[workerID] = &worker.Worker{
s.workers[workerID] = &WorkerInfo{
ID: workerID,
Capacity: 10,
CurrentLoad: 5,
Expand All @@ -1273,7 +1272,7 @@ func TestServer_UtilityFunctions(t *testing.T) {

// Create a worker
workerID := "test-worker-2"
s.workers[workerID] = &worker.Worker{
s.workers[workerID] = &WorkerInfo{
ID: workerID,
Capacity: 10,
CurrentLoad: 5,
Expand Down Expand Up @@ -1328,7 +1327,7 @@ func TestServer_UtilityFunctions(t *testing.T) {
s := NewServer()
defer s.cancel()

testWorker := &worker.Worker{ID: "test-worker-1"}
testWorker := &WorkerInfo{ID: "test-worker-1"}
s.workers["test-worker-1"] = testWorker

// Test finding existing worker
Expand Down
40 changes: 40 additions & 0 deletions internal/server/worker_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package server

import (
"time"

"github.com/google/uuid"
"github.com/mateusmlo/taskqueue/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// WorkerInfo tracks a registered worker (server-side only)
type WorkerInfo struct {
ID string
Address string
TaskTypes []string
Capacity int
CurrentLoad int
RegisteredAt time.Time
LastHeartbeat time.Time
Metadata map[string]string
}

func (wi *WorkerInfo) FromProtoWorker(pw *proto.Worker) error {
uuid, err := uuid.NewV7()
if err != nil {
return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err)
}

wi.ID = uuid.String()
wi.TaskTypes = pw.TaskTypes
wi.Address = pw.Metadata["address"]
wi.Capacity = int(pw.Capacity)
wi.CurrentLoad = 0
wi.Metadata = pw.Metadata
wi.RegisteredAt = time.Now()
wi.LastHeartbeat = time.Now()

return nil
}
249 changes: 224 additions & 25 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,239 @@
package worker

import (
"context"
"log"
"sync"
"time"

"github.com/google/uuid"
"github.com/mateusmlo/taskqueue/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type Worker struct {
ID string
Address string
RegisteredAt time.Time
LastHeartbeat time.Time
TaskTypes []string
Capacity int
CurrentLoad int
Metadata map[string]string
}

// FromProtoWorker initializes a Worker instance from a proto.Worker message (server generates ID)
func (w *Worker) FromProtoWorker(pw *proto.Worker) error {
uuid, err := uuid.NewV7()
serverAddr string
conn *grpc.ClientConn
client proto.WorkerServiceClient

id string
capacity int

handlers map[string]TaskHandler
currentLoad int
loadMux sync.RWMutex

ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}

type TaskHandler interface {
Handle(ctx context.Context, payload []byte) ([]byte, error)
}

func NewWorker(serverAddr string, capacity int) *Worker {
ctx, cancel := context.WithCancel(context.Background())

return &Worker{
serverAddr: serverAddr,
capacity: capacity,
handlers: make(map[string]TaskHandler),
ctx: ctx,
cancel: cancel,
}
}

func (w *Worker) RegisterHandler(taskType string, handler TaskHandler) {
w.handlers[taskType] = handler
}

func (w *Worker) Start() error {
tcr, err := credentials.NewClientTLSFromFile("./cert/server.crt", "localhost")
if err != nil {
return status.Errorf(codes.Internal, "failed to generate worker UUID: %v", err)
return err
}

conn, err := grpc.NewClient(w.serverAddr, grpc.WithTransportCredentials(tcr))
if err != nil {
return err
}

w.conn = conn
w.client = proto.NewWorkerServiceClient(w.conn)

if err := w.register(); err != nil {
w.conn.Close()
return err
}

w.wg.Add(2)
go w.heartbeatLoop()
go w.fetchLoop()

return nil
}

func (w *Worker) Stop() {
w.cancel()
w.wg.Wait()

if w.conn != nil {
if err := w.conn.Close(); err != nil {
log.Printf("Error closing gRPC connection: %v", err)
}

w.conn = nil
}
}

func (w *Worker) heartbeatLoop() {
defer w.wg.Done()

ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
req := w.buildHeartbeatRequest()
_, err := w.client.Heartbeat(w.ctx, req)
if err != nil {
log.Printf("Worker heartbeat error: %v", err)
}
case <-w.ctx.Done():
return
}
}
}

func (w *Worker) fetchLoop() {
defer w.wg.Done()

ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
req := w.buildFetchTasksRequest()
res, err := w.client.FetchTask(w.ctx, req)
if err != nil {
log.Printf("Worker fetch task error: %v", err)
continue
}

if !res.HasTask {
continue
}

handler, exists := w.handlers[res.Task.Type]
if !exists {
log.Printf("No handler registered for task type: %s", res.Task.Type)
continue
}

w.incrementLoad()

handleTask := w.getTaskHandler(handler)

go handleTask(res.Task)
case <-w.ctx.Done():
return
}
}
}

func (w *Worker) getTaskHandler(handler TaskHandler) func(task *proto.Task) {
return func(task *proto.Task) {
defer w.decrementLoad()

result, err := handler.Handle(w.ctx, task.Payload)
submitReq := &proto.SubmitResultRequest{
TaskId: task.Id,
}
if err != nil {
submitReq.Error = err.Error()
submitReq.Result = nil
} else {
submitReq.Error = ""
submitReq.Result = result
}

w.ID = uuid.String()
w.TaskTypes = pw.TaskTypes
w.Address = pw.Metadata["address"]
w.Capacity = int(pw.Capacity)
w.CurrentLoad = 0
w.Metadata = pw.Metadata
w.RegisteredAt = time.Now()
w.LastHeartbeat = time.Now()
_, err = w.client.SubmitResult(w.ctx, submitReq)
if err != nil {
log.Printf("Error submitting task result: %v", err)
}
}
}

func (w *Worker) getCurrentLoad() int32 {
w.loadMux.RLock()
defer w.loadMux.RUnlock()

return int32(w.currentLoad)
}

func (w *Worker) incrementLoad() {
w.loadMux.Lock()
defer w.loadMux.Unlock()

w.currentLoad++
}

func (w *Worker) decrementLoad() {
w.loadMux.Lock()
defer w.loadMux.Unlock()

if w.currentLoad > 0 {
w.currentLoad--
}
}

func (w *Worker) register() error {
req := w.buildRegisterRequest()

res, err := w.client.RegisterWorker(w.ctx, req)
if err != nil {
return err
}

w.id = res.WorkerId
return nil
}

func (w *Worker) buildRegisterRequest() *proto.RegisterWorkerRequest {
taskTypes := make([]string, 0, len(w.handlers))
for taskType := range w.handlers {
taskTypes = append(taskTypes, taskType)
}

return &proto.RegisterWorkerRequest{
Worker: &proto.Worker{
TaskTypes: taskTypes,
Capacity: int32(w.capacity),
Metadata: map[string]string{
"address": w.serverAddr,
},
},
}
}

func (w *Worker) buildFetchTasksRequest() *proto.FetchTaskRequest {
taskTypes := make([]string, 0, len(w.handlers))
for taskType := range w.handlers {
taskTypes = append(taskTypes, taskType)
}

return &proto.FetchTaskRequest{
WorkerId: w.id,
TaskTypes: taskTypes,
}
}

func (w *Worker) buildHeartbeatRequest() *proto.HeartbeatRequest {
return &proto.HeartbeatRequest{
WorkerId: w.id,
CurrentLoad: w.getCurrentLoad(),
}
}
Loading
Loading