diff --git a/cmd/tls/handlers/post.go b/cmd/tls/handlers/post.go index ce43f97b..6b0d75d4 100644 --- a/cmd/tls/handlers/post.go +++ b/cmd/tls/handlers/post.go @@ -149,6 +149,14 @@ func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) { } // We need to update the node info in another go routine if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { + // Check if node belongs to the environment + if node.EnvironmentID != env.ID { + log.Warn().Msgf("node UUID: %s in %s environment does not belong to the environment", node.UUID, env.Name) + response = types.ConfigResponse{NodeInvalid: true} + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, response) + return + } + // Node and environment match, so we can proceed to update the node ip := utils.GetIP(r) if ip == node.IPAddress { ip = "" @@ -236,9 +244,17 @@ func (h *HandlersTLS) LogHandler(w http.ResponseWriter, r *http.Request) { } }() var nodeInvalid bool + var response types.LogResponse // Check if provided node_key is valid and if so, update node - node, err := h.Nodes.GetByKey(t.NodeKey) - if err == nil { + if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { + // Check if node belongs to the environment + if node.EnvironmentID != env.ID { + log.Warn().Msgf("node UUID: %s in %s environment does not belong to the environment", node.UUID, env.Name) + response = types.LogResponse{NodeInvalid: true} + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, response) + return + } + // Node and environment match, so we can proceed to update the node nodeInvalid = false // Record ingested data requestSize.WithLabelValues(string(env.UUID), "LogHandler").Observe(float64(len(body))) @@ -254,7 +270,7 @@ func (h *HandlersTLS) LogHandler(w http.ResponseWriter, r *http.Request) { nodeInvalid = true } // Prepare response - response := types.LogResponse{NodeInvalid: nodeInvalid} + response = types.LogResponse{NodeInvalid: nodeInvalid} // Debug if (*h.EnvsMap)[env.Name].DebugHTTP { log.Debug().Msgf("Response: %+v", response) @@ -301,13 +317,22 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) { return } var nodeInvalid, accelerate bool + var response interface{} qs := make(queries.QueryReadQueries) // Check if provided node_key is valid and if so, update node if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { + // Check if node belongs to the environment + if node.EnvironmentID != env.ID { + log.Warn().Msgf("node UUID: %s in %s environment does not belong to the environment", node.UUID, env.Name) + response = types.ConfigResponse{NodeInvalid: true} + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, response) + return + } + // Node and environment match, so we can proceed // Record ingested data requestSize.WithLabelValues(string(env.UUID), "QueryRead").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryReadHandler endpoint", node.UUID, env.Name, len(body)) - + // Get queries and update node nodeInvalid = false qs, accelerate, err = h.Queries.NodeQueries(node) if err != nil { @@ -325,8 +350,7 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) { nodeInvalid = true accelerate = false } - // Prepare response and serialize queries - var response interface{} + // Serialize queries if accelerate { sAccelerate := int((*h.SettingsMap)[settings.AcceleratedSeconds].Integer) response = types.AcceleratedQueryReadResponse{Queries: qs, Accelerate: sAccelerate, NodeInvalid: nodeInvalid} @@ -379,8 +403,17 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request) return } var nodeInvalid bool + var response types.QueryWriteResponse // Check if provided node_key is valid and if so, update node if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { + // Check if node belongs to the environment + if node.EnvironmentID != env.ID { + log.Warn().Msgf("node UUID: %s in %s environment does not belong to the environment", node.UUID, env.Name) + response = types.QueryWriteResponse{NodeInvalid: true} + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, response) + return + } + // Node and environment match, so we can proceed // Record ingested data requestSize.WithLabelValues(string(env.UUID), "QueryWrite").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for QueryWriteHandler endpoint", node.UUID, env.Name, len(body)) @@ -415,7 +448,7 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request) nodeInvalid = true } // Prepare response - response := types.QueryWriteResponse{NodeInvalid: nodeInvalid} + response = types.QueryWriteResponse{NodeInvalid: nodeInvalid} // Debug HTTP if (*h.EnvsMap)[env.Name].DebugHTTP { log.Debug().Msgf("Response: %+v", response) @@ -611,12 +644,21 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) { } initCarve := false var carveSessionID string + var response types.CarveInitResponse // Check if provided node_key is valid and if so, update node if node, err := h.Nodes.GetByKey(t.NodeKey); err == nil { + // Check if node belongs to the environment + if node.EnvironmentID != env.ID { + log.Warn().Msgf("node UUID: %s in %s environment does not belong to the environment", node.UUID, env.Name) + response = types.CarveInitResponse{Success: false, SessionID: ""} + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, response) + return + } + // Node and environment match, so we can proceed // Record ingested data requestSize.WithLabelValues(string(env.UUID), "CarveInit").Observe(float64(len(body))) log.Debug().Msgf("node UUID: %s in %s environment ingested %d bytes for CarveInitHandler endpoint", node.UUID, env.Name, len(body)) - + // Initialize carve initCarve = true carveSessionID = generateCarveSessionID() // Process carve init @@ -632,7 +674,7 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) { h.WriteHandler.addEvent(lastSeenUpdate{NodeID: node.ID, IP: ip}) } // Prepare response - response := types.CarveInitResponse{Success: initCarve, SessionID: carveSessionID} + response = types.CarveInitResponse{Success: initCarve, SessionID: carveSessionID} // Debug HTTP if (*h.EnvsMap)[env.Name].DebugHTTP { log.Debug().Msgf("Response: %+v", response)