Rate limit on registration

This commit is contained in:
bdnugget 2025-01-19 22:06:41 +01:00
parent 3f7205d73e
commit 27da845b11
2 changed files with 91 additions and 7 deletions

View File

@ -5,6 +5,8 @@ import (
"database/sql"
"encoding/hex"
"errors"
"fmt"
"sync"
"time"
_ "github.com/mattn/go-sqlite3"
@ -17,6 +19,64 @@ var (
ErrInvalidCredentials = errors.New("invalid username or password")
)
const (
maxRegistrationsPerIP = 3 // Maximum registrations allowed per IP
registrationWindow = 24 * time.Hour // Time window for rate limiting
)
type registrationAttempt struct {
count int
firstTry time.Time
}
var (
registrationAttempts = make(map[string]*registrationAttempt)
rateLimitMutex sync.RWMutex
)
func CleanupOldAttempts() {
rateLimitMutex.Lock()
defer rateLimitMutex.Unlock()
now := time.Now()
for ip, attempt := range registrationAttempts {
if now.Sub(attempt.firstTry) > registrationWindow {
delete(registrationAttempts, ip)
}
}
}
func CheckRegistrationLimit(ip string) error {
rateLimitMutex.Lock()
defer rateLimitMutex.Unlock()
now := time.Now()
attempt, exists := registrationAttempts[ip]
if !exists {
registrationAttempts[ip] = &registrationAttempt{
count: 1,
firstTry: now,
}
return nil
}
// Reset if window has passed
if now.Sub(attempt.firstTry) > registrationWindow {
attempt.count = 1
attempt.firstTry = now
return nil
}
if attempt.count >= maxRegistrationsPerIP {
return fmt.Errorf("registration limit reached for this IP. Please try again in %v",
registrationWindow-now.Sub(attempt.firstTry))
}
attempt.count++
return nil
}
func InitDB(dbPath string) error {
var err error
db, err = sql.Open("sqlite3", dbPath)

26
main.go
View File

@ -36,7 +36,6 @@ var (
mu sync.RWMutex // Add mutex for protecting shared maps
chatHistory = make([]*pb.ChatMessage, 0, 100)
chatMutex sync.RWMutex
nextPlayerID = 1 // Assuming player IDs start from 1
)
func main() {
@ -55,6 +54,15 @@ func main() {
ticker := time.NewTicker(tickRate)
defer ticker.Stop()
// Start registration attempt cleanup goroutine
go func() {
ticker := time.NewTicker(time.Hour)
defer ticker.Stop()
for range ticker.C {
db.CleanupOldAttempts()
}
}()
// Handle incoming connections in a separate goroutine
go func() {
for {
@ -76,6 +84,14 @@ func main() {
func handleConnection(conn net.Conn) {
defer conn.Close()
// Get client IP
remoteAddr := conn.RemoteAddr().String()
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
log.Printf("Failed to parse remote address: %v", err)
return
}
// Read initial message for player ID
reader := bufio.NewReader(conn)
@ -130,6 +146,14 @@ func handleConnection(conn net.Conn) {
switch action.Type {
case pb.Action_REGISTER:
if err := db.CheckRegistrationLimit(ip); err != nil {
response := &pb.ServerMessage{
AuthSuccess: false,
ErrorMessage: err.Error(),
}
writeMessage(conn, response)
return
}
playerID, authErr = db.RegisterPlayer(action.Username, action.Password)
case pb.Action_LOGIN:
playerID, authErr = db.AuthenticatePlayer(action.Username, action.Password)