ws-go-server/main.go
2025-05-18 23:54:04 +02:00

235 lines
5.8 KiB
Go

package main
import (
"btclock/broker"
_ "btclock/clients"
"btclock/handlers"
"btclock/modules"
ws "btclock/websocket"
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/vmihailenco/msgpack/v5"
)
// Message is the control frame a v2 WebSocket client sends to the server.
// handleBinaryMessage decodes it from msgpack; the json tags mirror the
// msgpack tag names.
type Message struct {
	// Type selects the action; handleBinaryMessage accepts "subscribe" and
	// "unsubscribe".
	Type string `msgpack:"type" json:"type"`

	// EventType names the channel to (un)subscribe. The value "price" is
	// special-cased: one "price:<currency>" channel per Currencies entry.
	EventType string `msgpack:"eventType" json:"eventType"`

	// Currencies is only consulted when EventType is "price".
	Currencies []string `msgpack:"currencies" json:"currencies"`
}
// App bundles the HTTP server with the event plumbing that feeds WebSocket
// clients.
type App struct {
	// app serves the static files and the WebSocket routes.
	app *fiber.App

	// eventBroker is created by initializeModules and handed to every
	// module's Init; the WebSocket handlers are registered on it.
	eventBroker *broker.EventBroker

	// channelManager is created by setupWebSocketHandlers; connected clients
	// are registered with it and subscribe to channels through it.
	channelManager *ws.ChannelManager
}
// NewApp returns an App wrapping a Fiber instance with default configuration.
// eventBroker and channelManager are left nil; initializeModules and
// setupWebSocketHandlers populate them.
func NewApp() *App {
	a := new(App)
	a.app = fiber.New()
	return a
}
// initializeModules creates the shared event broker, calls Init on every
// module in modules.Registry with it, and then launches each module's Start
// in its own goroutine.
//
// Initialization is all-or-nothing: the first Init error is returned wrapped
// and no module is started. Start errors happen asynchronously, so they are
// only logged.
func (a *App) initializeModules() error {
	a.eventBroker = broker.NewEventBroker()

	registered := modules.Registry.GetAllModules()
	log.Printf("Initializing %d modules", len(registered))

	// Every module must be initialized before any of them starts.
	for _, mod := range registered {
		if err := mod.Init(a.eventBroker); err != nil {
			return fmt.Errorf("failed to initialize module %s: %w", mod.ID(), err)
		}
		log.Printf("Module initialized: %s", mod.ID())
	}

	for _, mod := range registered {
		// The module is passed as an argument so each goroutine gets its own copy.
		// NOTE(review): context.Background() is never cancelled, so modules are
		// not told about server shutdown.
		go func(started modules.Module) {
			if err := started.Start(context.Background()); err != nil {
				log.Printf("Module %s error: %v", started.ID(), err)
			}
		}(mod)
	}

	return nil
}
// setupWebSocketHandlers creates the channel manager and runs its Start
// method in a background goroutine. It then registers the v2 and v1 WebSocket
// handlers on the event broker for the "ticker", "mempool-fee-rate" and
// "mempool-block" events.
//
// a.eventBroker must already be set; initializeModules creates it.
func (a *App) setupWebSocketHandlers() {
	a.channelManager = ws.NewChannelManager()
	go a.channelManager.Start()

	forwarded := []string{"ticker", "mempool-fee-rate", "mempool-block"}

	// Registration order: the v2 handler for every event, then the v1 handler.
	v2 := &handlers.WebSocketV2Handler{Cm: a.channelManager}
	for _, event := range forwarded {
		a.eventBroker.Register(event, v2)
	}

	v1 := &handlers.WebSocketV1Handler{Cm: a.channelManager}
	for _, event := range forwarded {
		a.eventBroker.Register(event, v1)
	}
}
// setupRoutes mounts the routes:
//   - a static file server for ./static (index.html as directory index,
//     listings disabled);
//   - the v1 WebSocket endpoint at /api/v1/ws (handleWebSocketV1);
//   - the v2 WebSocket endpoint at /api/v2/ws (handleWebSocket).
func (a *App) setupRoutes() {
	staticFiles := filesystem.New(filesystem.Config{
		Root:   http.Dir("./static"),
		Index:  "index.html",
		Browse: false,
	})
	a.app.Use(staticFiles)

	a.app.Get("/api/v1/ws", websocket.New(a.handleWebSocketV1))
	a.app.Get("/api/v2/ws", websocket.New(a.handleWebSocket))
}
// handleWebSocketV1 serves a legacy v1 connection.
//
// The client is registered with the channel manager and subscribed to the
// fixed v1 channels ("blockheight-v1", "blockfee-v1", "price-v1"). Incoming
// frames are read and ignored, so a v1 client cannot change its
// subscriptions. The handler returns on the first read error; unexpected
// (non-normal) closes are logged. The client is unregistered on return.
func (a *App) handleWebSocketV1(c *websocket.Conn) {
	client := &ws.Client{Conn: c, Channels: make(map[string]bool)}
	a.channelManager.Register(client)
	defer a.channelManager.Unregister(client)

	for _, channel := range []string{"blockheight-v1", "blockfee-v1", "price-v1"} {
		client.Subscribe(channel, a.channelManager)
	}

	// Payloads are discarded; the loop only exits on a read error.
	for {
		if _, _, err := c.ReadMessage(); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
				log.Printf("WebSocket connection closed: %v", err)
			}
			return
		}
	}
}
// handleWebSocket serves a v2 connection.
//
// The client starts with no subscriptions. Each binary frame is passed to
// handleBinaryMessage, and any error it returns is logged without closing the
// connection. Non-binary frames are ignored. The handler returns on the first
// read error (logging unexpected closes) and unregisters the client.
func (a *App) handleWebSocket(c *websocket.Conn) {
	client := &ws.Client{Conn: c, Channels: make(map[string]bool)}
	a.channelManager.Register(client)
	defer a.channelManager.Unregister(client)

	for {
		frameType, payload, err := c.ReadMessage()
		switch {
		case err != nil:
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
				log.Printf("WebSocket connection closed: %v", err)
			}
			return
		case frameType != websocket.BinaryMessage:
			continue
		}

		if handleErr := a.handleBinaryMessage(client, payload); handleErr != nil {
			log.Printf("Error handling message: %v", handleErr)
		}
	}
}
// handleBinaryMessage decodes a msgpack-encoded Message sent by a v2 client.
// It dispatches on msg.Type to handleSubscribe ("subscribe") or
// handleUnsubscribe ("unsubscribe").
//
// It returns an error if the payload cannot be decoded or the type is not
// recognised.
func (a *App) handleBinaryMessage(client *ws.Client, message []byte) error {
	var msg Message
	if err := msgpack.Unmarshal(message, &msg); err != nil {
		return fmt.Errorf("error unmarshalling message: %w", err)
	}

	switch msg.Type {
	case "subscribe":
		return a.handleSubscribe(client, msg)
	case "unsubscribe":
		return a.handleUnsubscribe(client, msg)
	default:
		// msg.Type is client-controlled and this error is logged by the caller.
		// %q escapes newlines and control characters, so a client cannot forge
		// extra log lines.
		return fmt.Errorf("unknown message type: %q", msg.Type)
	}
}
func (a *App) handleSubscribe(client *ws.Client, msg Message) error {
if msg.EventType == "price" {
for _, currency := range msg.Currencies {
channel := fmt.Sprintf("price:%s", currency)
client.Subscribe(channel, a.channelManager)
log.Printf("[%s] Subscribed to channel: %s", client.Conn.RemoteAddr().String(), channel)
}
} else {
client.Subscribe(msg.EventType, a.channelManager)
log.Printf("[%s] Subscribed to channel: %s", client.Conn.RemoteAddr().String(), msg.EventType)
}
return nil
}
func (a *App) handleUnsubscribe(client *ws.Client, msg Message) error {
if msg.EventType == "price" {
for _, currency := range msg.Currencies {
channel := fmt.Sprintf("price:%s", currency)
client.Unsubscribe(channel)
log.Printf("[%s] Unsubscribed from channel: %s", client.Conn.RemoteAddr().String(), channel)
}
} else {
client.Unsubscribe(msg.EventType)
log.Printf("[%s] Unsubscribed from channel: %s", client.Conn.RemoteAddr().String(), msg.EventType)
}
return nil
}
// start serves the Fiber app on ":$PORT" (port 80 when PORT is unset) and
// blocks until either the server stops on its own or the process receives
// SIGINT/SIGTERM. On a signal it shuts the server down gracefully, bounded by
// a five-second timeout.
//
// It returns an error if the listener fails (for example, the port is already
// in use) or if graceful shutdown fails.
func (a *App) start() error {
	port := os.Getenv("PORT")
	if port == "" {
		port = "80"
	}
	addr := ":" + port

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigChan)

	// Buffered so the server goroutine can always deliver its result and exit,
	// even after start has stopped waiting on this channel.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- a.app.Listen(addr)
	}()

	// A Listen failure must end start. Waiting only on signals would leave the
	// process idling with no server.
	select {
	case err := <-serveErr:
		if err != nil {
			return fmt.Errorf("listening on %s: %w", addr, err)
		}
		return nil
	case <-sigChan:
	}

	log.Println("Shutting down server...")

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := a.app.ShutdownWithContext(ctx); err != nil {
		return fmt.Errorf("error during server shutdown: %w", err)
	}
	return nil
}
// main wires the application together and serves until start returns:
//  1. initializes the modules and the event broker;
//  2. sets up the WebSocket handlers, which need the broker;
//  3. sets up the HTTP routes.
//
// Any setup or server error is fatal.
func main() {
	app := NewApp()

	err := app.initializeModules()
	if err != nil {
		log.Fatalf("Failed to initialize modules: %v", err)
	}

	app.setupWebSocketHandlers()
	app.setupRoutes()

	if err = app.start(); err != nil {
		log.Fatalf("Server error: %v", err)
	}
}