From 28449287695dc9b4697f6db0d807878b5d3f4b38 Mon Sep 17 00:00:00 2001 From: nilo Date: Thu, 22 May 2025 10:27:05 -0300 Subject: [PATCH] websocket second version --- controllers/websocketController.go | 95 +++++++++++++++++++++--------- main.go | 12 +++- 2 files changed, 77 insertions(+), 30 deletions(-) diff --git a/controllers/websocketController.go b/controllers/websocketController.go index 805087c..1a900ec 100644 --- a/controllers/websocketController.go +++ b/controllers/websocketController.go @@ -2,8 +2,10 @@ package controllers import ( "encoding/json" + "fmt" "log" "sync" + "time" "github.com/gofiber/websocket/v2" ) @@ -12,6 +14,7 @@ import ( type WebSocketManager struct { connections map[*websocket.Conn]bool mutex sync.RWMutex + broadcast chan WebSocketMessage } // WebSocketMessage represents the structure of messages @@ -24,6 +27,37 @@ type WebSocketMessage struct { // Global instance of WebSocketManager var WSManager = &WebSocketManager{ connections: make(map[*websocket.Conn]bool), + broadcast: make(chan WebSocketMessage, 100), // Buffer size of 100 +} + +func init() { + // Start the broadcast handler + go WSManager.handleBroadcasts() +} + +func (m *WebSocketManager) handleBroadcasts() { + for message := range m.broadcast { + jsonMessage, err := json.Marshal(message) + if err != nil { + log.Printf("Error marshaling message: %v", err) + continue + } + + m.mutex.RLock() + for conn := range m.connections { + // Send message asynchronously + go func(c *websocket.Conn) { + writeTimeout := time.Now().Add(time.Second * 5) + c.SetWriteDeadline(writeTimeout) + + if err := c.WriteMessage(websocket.TextMessage, jsonMessage); err != nil { + log.Printf("Error sending message: %v", err) + m.removeConnection(c) + } + }(conn) + } + m.mutex.RUnlock() + } } // BroadcastMessage sends a message to all connected clients @@ -33,24 +67,13 @@ func (m *WebSocketManager) BroadcastMessage(command, channel, text string) { Channel: channel, Text: text, } - - jsonMessage, err := json.Marshal(message) - if err != nil { - log.Printf("Error marshaling message: %v", err) - return - } - m.mutex.RLock() - defer m.mutex.RUnlock() - - for conn := range m.connections { - if err := conn.WriteMessage(websocket.TextMessage, jsonMessage); err != nil { - log.Printf("Error broadcasting message: %v", err) - // Remove failed connection - m.mutex.RUnlock() - m.removeConnection(conn) - m.mutex.RLock() - } + // Non-blocking send to broadcast channel + select { + case m.broadcast <- message: + // Message queued successfully + default: + log.Println("Broadcast channel full, message dropped") } } @@ -61,7 +84,7 @@ func (m *WebSocketManager) SendMessageToClient(conn *websocket.Conn, command, ch Channel: channel, Text: text, } - + jsonMessage, err := json.Marshal(message) if err != nil { return err @@ -71,6 +94,8 @@ func (m *WebSocketManager) SendMessageToClient(conn *websocket.Conn, command, ch defer m.mutex.RUnlock() if _, exists := m.connections[conn]; exists { + writeTimeout := time.Now().Add(time.Second * 5) + conn.SetWriteDeadline(writeTimeout) return conn.WriteMessage(websocket.TextMessage, jsonMessage) } return nil @@ -91,6 +116,9 @@ func (m *WebSocketManager) removeConnection(conn *websocket.Conn) { // WebsocketHandler handles WebSocket connections func WebsocketHandler(c *websocket.Conn) { + // Set read deadline + c.SetReadDeadline(time.Now().Add(time.Second * 60)) // 1 minute timeout + // Add the connection to our manager WSManager.addConnection(c) defer WSManager.removeConnection(c) @@ -103,17 +131,26 @@ func WebsocketHandler(c *websocket.Conn) { break } - // Try to parse the incoming message as JSON - var wsMessage WebSocketMessage - if err := json.Unmarshal(message, &wsMessage); err != nil { - log.Printf("Error parsing message: %v", err) - continue - } + // Process message asynchronously + go func() { + var wsMessage WebSocketMessage + if err := json.Unmarshal(message, &wsMessage); err != nil { + log.Printf("Error parsing message: %v", err) + return + } - // Echo the message back to the client - if err := WSManager.SendMessageToClient(c, wsMessage.Command, wsMessage.Channel, wsMessage.Text); err != nil { - log.Println("Error writing message:", err) - break - } + switch wsMessage.Command { + case "chat": + WSManager.BroadcastMessage("chat", wsMessage.Channel, wsMessage.Text) + case "join": + WSManager.BroadcastMessage("system", wsMessage.Channel, + fmt.Sprintf("User joined channel %s", wsMessage.Channel)) + case "leave": + WSManager.BroadcastMessage("system", wsMessage.Channel, + fmt.Sprintf("User left channel %s", wsMessage.Channel)) + default: + WSManager.SendMessageToClient(c, "error", "", "Unknown command") + } + }() } } diff --git a/main.go b/main.go index 38bdf29..a3ce341 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/websocket/v2" ) func main() { @@ -81,10 +82,19 @@ func main() { })) // Connects to database - if err := database.ConnectDB(); err != nil { + if err = database.ConnectDB(); err != nil { panic("Could not connect to database") } + // WebSocket configuration + app.Use("/ws", func(c *fiber.Ctx) error { + if websocket.IsWebSocketUpgrade(c) { + c.Locals("allowed", true) + return c.Next() + } + return fiber.ErrUpgradeRequired + }) + // Setup routes routes.Setup(app)