diff --git a/cmd/commander/main.go b/cmd/commander/main.go index d0b37d5..f7f0ce8 100644 --- a/cmd/commander/main.go +++ b/cmd/commander/main.go @@ -13,13 +13,8 @@ import ( func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - tokenManager, err := auth.NewTokenManager() - if err != nil { - log.Error().Err(err).Msg("Unsuccessful initialization of token manager") - return - } - aut := auth.NewAuthMiddleware(tokenManager) - uc := controller.NewPostController(tokenManager) + aut := auth.NewAuthMiddleware() + uc := controller.NewPostController() h := httphandler.NewHttpHandler(uc, aut) h.Start() diff --git a/cmd/queryer/main.go b/cmd/queryer/main.go index 4c3ab2b..e5bca0f 100644 --- a/cmd/queryer/main.go +++ b/cmd/queryer/main.go @@ -4,6 +4,7 @@ import ( "os" "github.com/L4B0MB4/EVTSRC/pkg/client" + tcpClient "github.com/L4B0MB4/EVTSRC/pkg/tcp/client" "github.com/PRYVT/posting/pkg/query/eventhandling" "github.com/PRYVT/posting/pkg/query/httphandler" "github.com/PRYVT/posting/pkg/query/httphandler/controller" @@ -32,21 +33,33 @@ func main() { log.Error().Err(err).Msg("Unsuccessful initialization of client") return } - tokenManager, err := auth.NewTokenManager() - if err != nil { - log.Error().Err(err).Msg("Unsuccessful initialization of token manager") - return - } eventRepo := utilsRepo.NewEventRepository(conn) userRepo := repository.NewUserRepository(conn) - uc := controller.NewPostController(userRepo, tokenManager) - aut := auth.NewAuthMiddleware(tokenManager) - h := httphandler.NewHttpHandler(uc, aut) - userEventHandler := eventhandling.NewPostEventHandler(userRepo) + uc := controller.NewPostController(userRepo, userEventHandler) + aut := auth.NewAuthMiddleware() + wsH := controller.NewWsController(userEventHandler) + h := httphandler.NewHttpHandler(uc, aut, wsH) eventPolling := eventpolling.NewEventPolling(c, eventRepo, userEventHandler) - go eventPolling.PollEvents() + tcpC, err := tcpClient.NewTcpEventClient() + if err != nil { + log.Error().Err(err).Msg("Unsuccessful initialization of tcp client") + return + } + channel := make(chan string, 1) + go tcpC.ListenForEvents(channel) + + eventPolling.PollEventsUntilEmpty() + go func() { + for { + select { + case event := <-channel: + log.Info().Msgf("Received event: %s", event) + eventPolling.PollEventsUntilEmpty() + } + } + }() h.Start() } diff --git a/go.mod b/go.mod index 5c8da37..f421eaf 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/PRYVT/posting go 1.23.1 require ( - github.com/L4B0MB4/EVTSRC v0.4.5 // indirect - github.com/PRYVT/utils v0.2.0 // indirect + github.com/L4B0MB4/EVTSRC v0.5.2 // indirect + github.com/PRYVT/utils v0.3.0-rc // indirect github.com/bytedance/sonic v1.12.2 // indirect github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect @@ -18,6 +18,7 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 67eb211..d54126c 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,17 @@ github.com/L4B0MB4/EVTSRC v0.4.5 h1:HA4tp4fa/oCPTCl3gTD2FkRjo+nFKWm4rLmpudxcxXg= github.com/L4B0MB4/EVTSRC v0.4.5/go.mod h1:hpyNdNWqikZ6dcm8dhZAXgnAXZQNGAfXgRw902zjby0= +github.com/L4B0MB4/EVTSRC v0.5.1 h1:EB/lK0FTWtepToOtRFJdUhBf6tlb1L0bIRbRTMwArsQ= +github.com/L4B0MB4/EVTSRC v0.5.1/go.mod h1:hpyNdNWqikZ6dcm8dhZAXgnAXZQNGAfXgRw902zjby0= +github.com/L4B0MB4/EVTSRC v0.5.2 h1:bAOYlUmcZ2bg8rWIKnHLnxZPgXrPYyrQbtlz49BKlp4= +github.com/L4B0MB4/EVTSRC v0.5.2/go.mod h1:hpyNdNWqikZ6dcm8dhZAXgnAXZQNGAfXgRw902zjby0= github.com/PRYVT/utils v0.1.2 h1:U9qhq+18iIblQDrM4I0fmJkvlZ+BCY+DIjjKI4ebtlk= github.com/PRYVT/utils v0.1.2/go.mod h1:b7zk2FAGwJ8BPJx2JQ8qd+bA59g5EY7Y1vZQPWZHK3s= github.com/PRYVT/utils v0.2.0 h1:hWdHchXlGOYlJ1nfMmGffq/EjFn3ncvzTgsGCLUpiEE= github.com/PRYVT/utils v0.2.0/go.mod h1:j61GmoyWWXgnCq/laZTIJm4yhD0PreLDMZnYQqjSv7w= +github.com/PRYVT/utils v0.2.1 h1:GiTbziM3lqRLc4EWGV28+T/aKaY+B80KTqnkBklf9q0= +github.com/PRYVT/utils v0.2.1/go.mod h1:j61GmoyWWXgnCq/laZTIJm4yhD0PreLDMZnYQqjSv7w= +github.com/PRYVT/utils v0.3.0-rc h1:q5PlfgI0pu7Pv6b1A30BC/3lGIIhth2oggAxPpf/r40= +github.com/PRYVT/utils v0.3.0-rc/go.mod h1:j61GmoyWWXgnCq/laZTIJm4yhD0PreLDMZnYQqjSv7w= github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= @@ -36,6 +44,8 @@ github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVI github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/pkg/command/httphandler/controller/post_controller.go b/pkg/command/httphandler/controller/post_controller.go index c06e4a3..5b1fbc3 100644 --- a/pkg/command/httphandler/controller/post_controller.go +++ b/pkg/command/httphandler/controller/post_controller.go @@ -11,17 +11,16 @@ import ( ) type PostController struct { - tokenManager *auth.TokenManager } -func NewPostController(tokenManager *auth.TokenManager) *PostController { - return &PostController{tokenManager: tokenManager} +func NewPostController() *PostController { + return &PostController{} } func (ctrl *PostController) CreatePost(c *gin.Context) { token := auth.GetTokenFromHeader(c) - userUuid, err := ctrl.tokenManager.GetUserUuidFromToken(token) + userUuid, err := auth.GetUserUuidFromToken(token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return diff --git a/pkg/query/eventhandling/post.go b/pkg/query/eventhandling/post.go index 96b6972..e255d1b 100644 --- a/pkg/query/eventhandling/post.go +++ b/pkg/query/eventhandling/post.go @@ -1,24 +1,46 @@ package eventhandling import ( + "sync" + "github.com/L4B0MB4/EVTSRC/pkg/models" "github.com/PRYVT/posting/pkg/aggregates" "github.com/PRYVT/posting/pkg/query/store/repository" + ws "github.com/PRYVT/posting/pkg/query/websocket" "github.com/google/uuid" "github.com/rs/zerolog/log" ) -type UserEventHandler struct { - postRepo *repository.PostRepository +type PostEventHandler struct { + postRepo *repository.PostRepository + wsConnections []*ws.WebsocketConnection + mu sync.Mutex } -func NewPostEventHandler(postRepo *repository.PostRepository) *UserEventHandler { - return &UserEventHandler{ - postRepo: postRepo, +func NewPostEventHandler(postRepo *repository.PostRepository) *PostEventHandler { + return &PostEventHandler{ + postRepo: postRepo, + wsConnections: []*ws.WebsocketConnection{}, } } -func (eh *UserEventHandler) HandleEvent(event models.Event) error { +func (eh *PostEventHandler) AddWebsocketConnection(conn *ws.WebsocketConnection) { + eh.mu.Lock() + defer eh.mu.Unlock() + eh.wsConnections = append(eh.wsConnections, conn) +} + +func removeDisconnectedSockets(slice []*ws.WebsocketConnection) []*ws.WebsocketConnection { + output := []*ws.WebsocketConnection{} + for _, element := range slice { + if element.IsConnected { + output = append(output, element) + } + } + return output +} + +func (eh *PostEventHandler) HandleEvent(event models.Event) error { if event.AggregateType == "post" { ua, err := aggregates.NewPostAggregate(uuid.MustParse(event.AggregateId)) if err != nil { @@ -30,6 +52,19 @@ func (eh *UserEventHandler) HandleEvent(event models.Event) error { log.Err(err).Msg("Error while processing user event") return err } + for _, conn := range eh.wsConnections { + if !conn.IsAuthenticated { + continue + } + err := conn.WriteJSON(p) + if err != nil { + log.Warn().Err(err).Msg("Error while writing to websocket connection") + } + } + eh.mu.Lock() + defer eh.mu.Unlock() + eh.wsConnections = removeDisconnectedSockets(eh.wsConnections) + } return nil } diff --git a/pkg/query/httphandler/controller/post_controller.go b/pkg/query/httphandler/controller/post_controller.go index 04ed391..f85e46c 100644 --- a/pkg/query/httphandler/controller/post_controller.go +++ b/pkg/query/httphandler/controller/post_controller.go @@ -6,17 +6,17 @@ import ( "github.com/PRYVT/posting/pkg/models/query" "github.com/PRYVT/posting/pkg/query/store/repository" "github.com/PRYVT/posting/pkg/query/utils" - "github.com/PRYVT/utils/pkg/auth" + "github.com/PRYVT/utils/pkg/eventpolling" "github.com/gin-gonic/gin" ) type PostController struct { - postRepo *repository.PostRepository - tokenManager *auth.TokenManager + postRepo *repository.PostRepository + userEventH eventpolling.EventHanlder } -func NewPostController(userRepo *repository.PostRepository, tokenManager *auth.TokenManager) *PostController { - return &PostController{postRepo: userRepo, tokenManager: tokenManager} +func NewPostController(userRepo *repository.PostRepository, userEventH eventpolling.EventHanlder) *PostController { + return &PostController{postRepo: userRepo, userEventH: userEventH} } func (ctrl *PostController) GetPost(c *gin.Context) { diff --git a/pkg/query/httphandler/controller/websocket_controller.go b/pkg/query/httphandler/controller/websocket_controller.go new file mode 100644 index 0000000..5d2f367 --- /dev/null +++ b/pkg/query/httphandler/controller/websocket_controller.go @@ -0,0 +1,36 @@ +package controller + +import ( + "net/http" + + "github.com/PRYVT/posting/pkg/query/eventhandling" + ws "github.com/PRYVT/posting/pkg/query/websocket" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +type WSController struct { + userEventH *eventhandling.PostEventHandler +} + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +func NewWsController(userEventH *eventhandling.PostEventHandler) *WSController { + + upgrader.CheckOrigin = func(r *http.Request) bool { return true } + return &WSController{userEventH: userEventH} +} + +func (w *WSController) OnRequest(c *gin.Context) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Warn().Err(err).Msg("Error while upgrading connection") + + } else { + w.userEventH.AddWebsocketConnection(ws.NewWebsocketConnection(conn)) + } +} diff --git a/pkg/query/httphandler/handler.go b/pkg/query/httphandler/handler.go index c8ff8ea..8c2588a 100644 --- a/pkg/query/httphandler/handler.go +++ b/pkg/query/httphandler/handler.go @@ -15,9 +15,10 @@ type HttpHandler struct { router *gin.Engine postController *controller.PostController authMiddleware *auth.AuthMiddleware + wsController *controller.WSController } -func NewHttpHandler(c *controller.PostController, am *auth.AuthMiddleware) *HttpHandler { +func NewHttpHandler(c *controller.PostController, am *auth.AuthMiddleware, wsController *controller.WSController) *HttpHandler { r := gin.Default() srv := &http.Server{ Addr: "0.0.0.0" + ":" + "5520", @@ -28,6 +29,7 @@ func NewHttpHandler(c *controller.PostController, am *auth.AuthMiddleware) *Http httpServer: srv, postController: c, authMiddleware: am, + wsController: wsController, } handler.RegisterRoutes() return handler @@ -35,6 +37,7 @@ func NewHttpHandler(c *controller.PostController, am *auth.AuthMiddleware) *Http func (h *HttpHandler) RegisterRoutes() { h.router.Use(auth.CORSMiddleware()) + h.router.GET("/ws", h.wsController.OnRequest) h.router.Use(h.authMiddleware.AuthenticateMiddleware) { h.router.GET("posts/:postId", h.postController.GetPost) diff --git a/pkg/query/websocket/auth_req.go b/pkg/query/websocket/auth_req.go new file mode 100644 index 0000000..41d45a2 --- /dev/null +++ b/pkg/query/websocket/auth_req.go @@ -0,0 +1,9 @@ +package websocket + +import "encoding/json" + +type AuthRequest struct { + Token string `json:"token"` + Type string `json:"type"` + Data json.RawMessage +} diff --git a/pkg/query/websocket/websocket_connection.go b/pkg/query/websocket/websocket_connection.go new file mode 100644 index 0000000..3c10ace --- /dev/null +++ b/pkg/query/websocket/websocket_connection.go @@ -0,0 +1,68 @@ +package websocket + +import ( + "fmt" + + "github.com/PRYVT/utils/pkg/auth" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +type WebsocketConnection struct { + connection *websocket.Conn + IsConnected bool + IsAuthenticated bool + userUuid uuid.UUID +} + +func NewWebsocketConnection(conn *websocket.Conn) *WebsocketConnection { + wC := &WebsocketConnection{connection: conn} + go wC.ReadForDisconnect() + return wC +} + +func (wC *WebsocketConnection) WriteJSON(v interface{}) error { + if !wC.IsAuthenticated { + return fmt.Errorf("WebsocketConnection is not connected or authenticated") + } + err := wC.connection.WriteJSON(v) + if err != nil { + log.Warn().Err(err).Msg("Error while writing WriteJSON") + } + return err +} + +func (wC *WebsocketConnection) ReadForDisconnect() { + wC.IsConnected = true + for { + authRequest := AuthRequest{} + err := wC.connection.ReadJSON(&authRequest) + if err != nil { + log.Debug().Err(err).Msg("Error while reading from websocket connection") + wC.IsAuthenticated = false + wC.connection.Close() + wC.IsConnected = false + break + } else { + _, err = auth.VerifyToken(authRequest.Token) + if err != nil { + log.Debug().Err(err).Msg("Error while verifying token") + wC.IsAuthenticated = false + wC.connection.Close() + wC.IsConnected = false + break + } + uuid, err := auth.GetUserUuidFromToken(authRequest.Token) + if err != nil { + log.Debug().Err(err).Msg("Error while getting user uuid from token") + wC.IsAuthenticated = false + wC.connection.Close() + wC.IsConnected = false + break + } + wC.userUuid = uuid + wC.IsAuthenticated = true + } + } +}