From 27da845b11f968aa77139c871becc6ce4c688038 Mon Sep 17 00:00:00 2001 From: bdnugget Date: Sun, 19 Jan 2025 22:06:41 +0100 Subject: [PATCH] Rate limit on registration --- db/db.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 38 ++++++++++++++++++++++++++++------- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/db/db.go b/db/db.go index 427a308..a2ba9c4 100644 --- a/db/db.go +++ b/db/db.go @@ -5,6 +5,8 @@ import ( "database/sql" "encoding/hex" "errors" + "fmt" + "sync" "time" _ "github.com/mattn/go-sqlite3" @@ -17,6 +19,64 @@ var ( ErrInvalidCredentials = errors.New("invalid username or password") ) +const ( + maxRegistrationsPerIP = 3 // Maximum registrations allowed per IP + registrationWindow = 24 * time.Hour // Time window for rate limiting +) + +type registrationAttempt struct { + count int + firstTry time.Time +} + +var ( + registrationAttempts = make(map[string]*registrationAttempt) + rateLimitMutex sync.RWMutex +) + +func CleanupOldAttempts() { + rateLimitMutex.Lock() + defer rateLimitMutex.Unlock() + + now := time.Now() + for ip, attempt := range registrationAttempts { + if now.Sub(attempt.firstTry) > registrationWindow { + delete(registrationAttempts, ip) + } + } +} + +func CheckRegistrationLimit(ip string) error { + rateLimitMutex.Lock() + defer rateLimitMutex.Unlock() + + now := time.Now() + attempt, exists := registrationAttempts[ip] + + if !exists { + registrationAttempts[ip] = ®istrationAttempt{ + count: 1, + firstTry: now, + } + return nil + } + + // Reset if window has passed + if now.Sub(attempt.firstTry) > registrationWindow { + attempt.count = 1 + attempt.firstTry = now + return nil + } + + if attempt.count >= maxRegistrationsPerIP { + return fmt.Errorf("registration limit reached for this IP. Please try again in %v", + registrationWindow-now.Sub(attempt.firstTry)) + } + + attempt.count++ + return nil +} + func InitDB(dbPath string) error { var err error db, err = sql.Open("sqlite3", dbPath) diff --git a/main.go b/main.go index e2b9cb1..3328bd9 100644 --- a/main.go +++ b/main.go @@ -30,13 +30,12 @@ type Player struct { } var ( - players = make(map[int]*Player) - actionQueue = make(map[int][]*pb.Action) // Queue to store actions for each player - playerConns = make(map[int]net.Conn) // Map to store player connections - mu sync.RWMutex // Add mutex for protecting shared maps - chatHistory = make([]*pb.ChatMessage, 0, 100) - chatMutex sync.RWMutex - nextPlayerID = 1 // Assuming player IDs start from 1 + players = make(map[int]*Player) + actionQueue = make(map[int][]*pb.Action) // Queue to store actions for each player + playerConns = make(map[int]net.Conn) // Map to store player connections + mu sync.RWMutex // Add mutex for protecting shared maps + chatHistory = make([]*pb.ChatMessage, 0, 100) + chatMutex sync.RWMutex ) func main() { @@ -55,6 +54,15 @@ func main() { ticker := time.NewTicker(tickRate) defer ticker.Stop() + // Start registration attempt cleanup goroutine + go func() { + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + for range ticker.C { + db.CleanupOldAttempts() + } + }() + // Handle incoming connections in a separate goroutine go func() { for { @@ -76,6 +84,14 @@ func main() { func handleConnection(conn net.Conn) { defer conn.Close() + // Get client IP + remoteAddr := conn.RemoteAddr().String() + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + log.Printf("Failed to parse remote address: %v", err) + return + } + // Read initial message for player ID reader := bufio.NewReader(conn) @@ -130,6 +146,14 @@ func handleConnection(conn net.Conn) { switch action.Type { case pb.Action_REGISTER: + if err := db.CheckRegistrationLimit(ip); err != nil { + response := &pb.ServerMessage{ + AuthSuccess: false, + ErrorMessage: err.Error(), + } + writeMessage(conn, response) + return + } playerID, authErr = db.RegisterPlayer(action.Username, action.Password) case pb.Action_LOGIN: playerID, authErr = db.AuthenticatePlayer(action.Username, action.Password)