Rate limit on registration
This commit is contained in:
		
							
								
								
									
										60
									
								
								db/db.go
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								db/db.go
									
									
									
									
									
								
							| @ -5,6 +5,8 @@ import ( | ||||
| 	"database/sql" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| @ -17,6 +19,64 @@ var ( | ||||
| 	ErrInvalidCredentials = errors.New("invalid username or password") | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	maxRegistrationsPerIP = 3              // Maximum registrations allowed per IP | ||||
| 	registrationWindow    = 24 * time.Hour // Time window for rate limiting | ||||
| ) | ||||
|  | ||||
| type registrationAttempt struct { | ||||
| 	count    int | ||||
| 	firstTry time.Time | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	registrationAttempts = make(map[string]*registrationAttempt) | ||||
| 	rateLimitMutex       sync.RWMutex | ||||
| ) | ||||
|  | ||||
| func CleanupOldAttempts() { | ||||
| 	rateLimitMutex.Lock() | ||||
| 	defer rateLimitMutex.Unlock() | ||||
|  | ||||
| 	now := time.Now() | ||||
| 	for ip, attempt := range registrationAttempts { | ||||
| 		if now.Sub(attempt.firstTry) > registrationWindow { | ||||
| 			delete(registrationAttempts, ip) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func CheckRegistrationLimit(ip string) error { | ||||
| 	rateLimitMutex.Lock() | ||||
| 	defer rateLimitMutex.Unlock() | ||||
|  | ||||
| 	now := time.Now() | ||||
| 	attempt, exists := registrationAttempts[ip] | ||||
|  | ||||
| 	if !exists { | ||||
| 		registrationAttempts[ip] = ®istrationAttempt{ | ||||
| 			count:    1, | ||||
| 			firstTry: now, | ||||
| 		} | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Reset if window has passed | ||||
| 	if now.Sub(attempt.firstTry) > registrationWindow { | ||||
| 		attempt.count = 1 | ||||
| 		attempt.firstTry = now | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if attempt.count >= maxRegistrationsPerIP { | ||||
| 		return fmt.Errorf("registration limit reached for this IP. Please try again in %v", | ||||
| 			registrationWindow-now.Sub(attempt.firstTry)) | ||||
| 	} | ||||
|  | ||||
| 	attempt.count++ | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func InitDB(dbPath string) error { | ||||
| 	var err error | ||||
| 	db, err = sql.Open("sqlite3", dbPath) | ||||
|  | ||||
							
								
								
									
										38
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								main.go
									
									
									
									
									
								
							| @ -30,13 +30,12 @@ type Player struct { | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	players      = make(map[int]*Player) | ||||
| 	actionQueue  = make(map[int][]*pb.Action) // Queue to store actions for each player | ||||
| 	playerConns  = make(map[int]net.Conn)     // Map to store player connections | ||||
| 	mu           sync.RWMutex                 // Add mutex for protecting shared maps | ||||
| 	chatHistory  = make([]*pb.ChatMessage, 0, 100) | ||||
| 	chatMutex    sync.RWMutex | ||||
| 	nextPlayerID = 1 // Assuming player IDs start from 1 | ||||
| 	players     = make(map[int]*Player) | ||||
| 	actionQueue = make(map[int][]*pb.Action) // Queue to store actions for each player | ||||
| 	playerConns = make(map[int]net.Conn)     // Map to store player connections | ||||
| 	mu          sync.RWMutex                 // Add mutex for protecting shared maps | ||||
| 	chatHistory = make([]*pb.ChatMessage, 0, 100) | ||||
| 	chatMutex   sync.RWMutex | ||||
| ) | ||||
|  | ||||
| func main() { | ||||
| @ -55,6 +54,15 @@ func main() { | ||||
| 	ticker := time.NewTicker(tickRate) | ||||
| 	defer ticker.Stop() | ||||
|  | ||||
| 	// Start registration attempt cleanup goroutine | ||||
| 	go func() { | ||||
| 		ticker := time.NewTicker(time.Hour) | ||||
| 		defer ticker.Stop() | ||||
| 		for range ticker.C { | ||||
| 			db.CleanupOldAttempts() | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// Handle incoming connections in a separate goroutine | ||||
| 	go func() { | ||||
| 		for { | ||||
| @ -76,6 +84,14 @@ func main() { | ||||
| func handleConnection(conn net.Conn) { | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	// Get client IP | ||||
| 	remoteAddr := conn.RemoteAddr().String() | ||||
| 	ip, _, err := net.SplitHostPort(remoteAddr) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to parse remote address: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Read initial message for player ID | ||||
| 	reader := bufio.NewReader(conn) | ||||
|  | ||||
| @ -130,6 +146,14 @@ func handleConnection(conn net.Conn) { | ||||
|  | ||||
| 	switch action.Type { | ||||
| 	case pb.Action_REGISTER: | ||||
| 		if err := db.CheckRegistrationLimit(ip); err != nil { | ||||
| 			response := &pb.ServerMessage{ | ||||
| 				AuthSuccess:  false, | ||||
| 				ErrorMessage: err.Error(), | ||||
| 			} | ||||
| 			writeMessage(conn, response) | ||||
| 			return | ||||
| 		} | ||||
| 		playerID, authErr = db.RegisterPlayer(action.Username, action.Password) | ||||
| 	case pb.Action_LOGIN: | ||||
| 		playerID, authErr = db.AuthenticatePlayer(action.Username, action.Password) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user