diff --git a/go.mod b/go.mod index c17f9ad..71c2c51 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index a74f80d..49baffe 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVI github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/pkg/auth/token_manager.go b/pkg/auth/token_manager.go index ff4991e..1e04789 100644 --- a/pkg/auth/token_manager.go +++ b/pkg/auth/token_manager.go @@ -42,7 +42,7 @@ func GetUserUuidFromToken(tokenString string) (uuid.UUID, error) { } -func CreateToken(userUuid uuid.UUID) (string, error) { +func CreateToken(userUuid uuid.UUID, duration time.Duration) (string, error) { signingSecret, err := getSigningSecret() if err != nil { @@ -53,7 +53,7 @@ func CreateToken(userUuid uuid.UUID) (string, error) { "sub": userUuid, "iss": "pryvt", "aud": "local-audience", - "exp": time.Now().Add(time.Minute * 30).Unix(), + "exp": time.Now().Add(duration).Unix(), "iat": time.Now().Unix(), }) diff --git a/pkg/auth/token_manager_test.go b/pkg/auth/token_manager_test.go index 8195080..81ca25a 100644 --- a/pkg/auth/token_manager_test.go +++ b/pkg/auth/token_manager_test.go @@ -3,6 +3,7 @@ package auth import ( "os" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -15,7 +16,7 @@ func TestGetUserUuidFromToken(t *testing.T) { // Create a test UUID testUuid := uuid.New() - tokenString, err := CreateToken(testUuid) + tokenString, err := CreateToken(testUuid, 1*time.Second) assert.NoError(t, err) returnedUuid, err := GetUserUuidFromToken(tokenString) diff --git a/pkg/eventpolling/event_handler.go b/pkg/eventpolling/event_handler.go deleted file mode 100644 index b729347..0000000 --- a/pkg/eventpolling/event_handler.go +++ /dev/null @@ -1,7 +0,0 @@ -package eventpolling - -import "github.com/L4B0MB4/EVTSRC/pkg/models" - -type EventHanlder interface { - HandleEvent(event models.Event) error -} diff --git a/pkg/eventpolling/polling.go b/pkg/eventpolling/polling.go index e92b404..921352b 100644 --- a/pkg/eventpolling/polling.go +++ b/pkg/eventpolling/polling.go @@ -4,6 +4,7 @@ import ( "time" "github.com/L4B0MB4/EVTSRC/pkg/client" + "github.com/PRYVT/utils/pkg/interfaces" "github.com/PRYVT/utils/pkg/store/repository" "github.com/rs/zerolog/log" ) @@ -11,10 +12,10 @@ import ( type EventPolling struct { client *client.EventSourcingHttpClient eventRepo *repository.EventRepository - eventHandler EventHanlder + eventHandler interfaces.EventHandler } -func NewEventPolling(client *client.EventSourcingHttpClient, eventRepo *repository.EventRepository, eventHandler EventHanlder) *EventPolling { +func NewEventPolling(client *client.EventSourcingHttpClient, eventRepo *repository.EventRepository, eventHandler interfaces.EventHandler) *EventPolling { if client == nil || eventRepo == nil || eventHandler == nil { return nil } diff --git a/pkg/interfaces/event_handler.go b/pkg/interfaces/event_handler.go new file mode 100644 index 0000000..660ccbc --- /dev/null +++ b/pkg/interfaces/event_handler.go @@ -0,0 +1,8 @@ +package interfaces + +import "github.com/L4B0MB4/EVTSRC/pkg/models" + +type EventHandler interface { + HandleEvent(event models.Event) error + AddWebsocketConnection(conn WebsocketConnecter) +} diff --git a/pkg/interfaces/websocket_connector.go b/pkg/interfaces/websocket_connector.go new file mode 100644 index 0000000..abae76e --- /dev/null +++ b/pkg/interfaces/websocket_connector.go @@ -0,0 +1,8 @@ +package interfaces + +type WebsocketConnecter interface { + WriteJSON(v interface{}) error + ReadForDisconnect() + IsAuthenticated() bool + IsConnected() bool +} diff --git a/pkg/websocket/auth_req.go b/pkg/websocket/auth_req.go new file mode 100644 index 0000000..41d45a2 --- /dev/null +++ b/pkg/websocket/auth_req.go @@ -0,0 +1,9 @@ +package websocket + +import "encoding/json" + +type AuthRequest struct { + Token string `json:"token"` + Type string `json:"type"` + Data json.RawMessage +} diff --git a/pkg/websocket/websocket_connection.go b/pkg/websocket/websocket_connection.go new file mode 100644 index 0000000..8e47a3f --- /dev/null +++ b/pkg/websocket/websocket_connection.go @@ -0,0 +1,83 @@ +package websocket + +import ( + "fmt" + + "github.com/PRYVT/utils/pkg/auth" + "github.com/PRYVT/utils/pkg/interfaces" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +type WebsocketConnectionInterfacer interface { + WriteJSON(v interface{}) error + ReadForDisconnect() +} + +type WebsocketConnection struct { + connection *websocket.Conn + isConnected bool + isAuthenticated bool + userUuid uuid.UUID +} + +func NewWebsocketConnection(conn *websocket.Conn) interfaces.WebsocketConnecter { + wC := &WebsocketConnection{connection: conn} + go wC.ReadForDisconnect() + return wC +} + +func (wC *WebsocketConnection) IsConnected() bool { + return wC.isConnected +} + +func (wC *WebsocketConnection) IsAuthenticated() bool { + return wC.isAuthenticated +} + +func (wC *WebsocketConnection) WriteJSON(v interface{}) error { + if !wC.isAuthenticated { + return fmt.Errorf("WebsocketConnection is not connected or authenticated") + } + err := wC.connection.WriteJSON(v) + if err != nil { + log.Warn().Err(err).Msg("Error while writing WriteJSON") + } + return err +} + +func (wC *WebsocketConnection) ReadForDisconnect() { + wC.isConnected = true + for { + authRequest := AuthRequest{} + err := wC.connection.ReadJSON(&authRequest) + if err != nil { + log.Debug().Err(err).Msg("Error while reading from websocket connection") + wC.isAuthenticated = false + wC.connection.Close() + wC.isConnected = false + break + } else { + log.Debug().Interface("authReq", authRequest).Msg("Received auth request") + _, err = auth.VerifyToken(authRequest.Token) + if err != nil { + log.Debug().Err(err).Msg("Error while verifying token") + wC.isAuthenticated = false + wC.connection.Close() + wC.isConnected = false + break + } + uuid, err := auth.GetUserUuidFromToken(authRequest.Token) + if err != nil { + log.Debug().Err(err).Msg("Error while getting user uuid from token") + wC.isAuthenticated = false + wC.connection.Close() + wC.isConnected = false + break + } + wC.userUuid = uuid + wC.isAuthenticated = true + } + } +} diff --git a/pkg/websocket/websocket_controller.go b/pkg/websocket/websocket_controller.go new file mode 100644 index 0000000..1137d49 --- /dev/null +++ b/pkg/websocket/websocket_controller.go @@ -0,0 +1,35 @@ +package websocket + +import ( + "net/http" + + "github.com/PRYVT/utils/pkg/interfaces" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/rs/zerolog/log" +) + +type WSController struct { + eventHandler interfaces.EventHandler +} + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +func NewWsController(eventHandler interfaces.EventHandler) *WSController { + + upgrader.CheckOrigin = func(r *http.Request) bool { return true } + return &WSController{eventHandler: eventHandler} +} + +func (w *WSController) OnRequest(c *gin.Context) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Warn().Err(err).Msg("Error while upgrading connection") + + } else { + w.eventHandler.AddWebsocketConnection(NewWebsocketConnection(conn)) + } +}