Rate limit on registration
This commit is contained in:
		
							
								
								
									
										60
									
								
								db/db.go
									
									
									
									
									
								
							
							
						
						
									
										60
									
								
								db/db.go
									
									
									
									
									
								
							| @ -5,6 +5,8 @@ import ( | |||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	_ "github.com/mattn/go-sqlite3" | 	_ "github.com/mattn/go-sqlite3" | ||||||
| @ -17,6 +19,64 @@ var ( | |||||||
| 	ErrInvalidCredentials = errors.New("invalid username or password") | 	ErrInvalidCredentials = errors.New("invalid username or password") | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	maxRegistrationsPerIP = 3              // Maximum registrations allowed per IP | ||||||
|  | 	registrationWindow    = 24 * time.Hour // Time window for rate limiting | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type registrationAttempt struct { | ||||||
|  | 	count    int | ||||||
|  | 	firstTry time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	registrationAttempts = make(map[string]*registrationAttempt) | ||||||
|  | 	rateLimitMutex       sync.RWMutex | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func CleanupOldAttempts() { | ||||||
|  | 	rateLimitMutex.Lock() | ||||||
|  | 	defer rateLimitMutex.Unlock() | ||||||
|  |  | ||||||
|  | 	now := time.Now() | ||||||
|  | 	for ip, attempt := range registrationAttempts { | ||||||
|  | 		if now.Sub(attempt.firstTry) > registrationWindow { | ||||||
|  | 			delete(registrationAttempts, ip) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CheckRegistrationLimit(ip string) error { | ||||||
|  | 	rateLimitMutex.Lock() | ||||||
|  | 	defer rateLimitMutex.Unlock() | ||||||
|  |  | ||||||
|  | 	now := time.Now() | ||||||
|  | 	attempt, exists := registrationAttempts[ip] | ||||||
|  |  | ||||||
|  | 	if !exists { | ||||||
|  | 		registrationAttempts[ip] = ®istrationAttempt{ | ||||||
|  | 			count:    1, | ||||||
|  | 			firstTry: now, | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Reset if window has passed | ||||||
|  | 	if now.Sub(attempt.firstTry) > registrationWindow { | ||||||
|  | 		attempt.count = 1 | ||||||
|  | 		attempt.firstTry = now | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if attempt.count >= maxRegistrationsPerIP { | ||||||
|  | 		return fmt.Errorf("registration limit reached for this IP. Please try again in %v", | ||||||
|  | 			registrationWindow-now.Sub(attempt.firstTry)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	attempt.count++ | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func InitDB(dbPath string) error { | func InitDB(dbPath string) error { | ||||||
| 	var err error | 	var err error | ||||||
| 	db, err = sql.Open("sqlite3", dbPath) | 	db, err = sql.Open("sqlite3", dbPath) | ||||||
|  | |||||||
							
								
								
									
										26
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								main.go
									
									
									
									
									
								
							| @ -36,7 +36,6 @@ var ( | |||||||
| 	mu          sync.RWMutex                 // Add mutex for protecting shared maps | 	mu          sync.RWMutex                 // Add mutex for protecting shared maps | ||||||
| 	chatHistory = make([]*pb.ChatMessage, 0, 100) | 	chatHistory = make([]*pb.ChatMessage, 0, 100) | ||||||
| 	chatMutex   sync.RWMutex | 	chatMutex   sync.RWMutex | ||||||
| 	nextPlayerID = 1 // Assuming player IDs start from 1 |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| @ -55,6 +54,15 @@ func main() { | |||||||
| 	ticker := time.NewTicker(tickRate) | 	ticker := time.NewTicker(tickRate) | ||||||
| 	defer ticker.Stop() | 	defer ticker.Stop() | ||||||
|  |  | ||||||
|  | 	// Start registration attempt cleanup goroutine | ||||||
|  | 	go func() { | ||||||
|  | 		ticker := time.NewTicker(time.Hour) | ||||||
|  | 		defer ticker.Stop() | ||||||
|  | 		for range ticker.C { | ||||||
|  | 			db.CleanupOldAttempts() | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	// Handle incoming connections in a separate goroutine | 	// Handle incoming connections in a separate goroutine | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
| @ -76,6 +84,14 @@ func main() { | |||||||
| func handleConnection(conn net.Conn) { | func handleConnection(conn net.Conn) { | ||||||
| 	defer conn.Close() | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	// Get client IP | ||||||
|  | 	remoteAddr := conn.RemoteAddr().String() | ||||||
|  | 	ip, _, err := net.SplitHostPort(remoteAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Printf("Failed to parse remote address: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Read initial message for player ID | 	// Read initial message for player ID | ||||||
| 	reader := bufio.NewReader(conn) | 	reader := bufio.NewReader(conn) | ||||||
|  |  | ||||||
| @ -130,6 +146,14 @@ func handleConnection(conn net.Conn) { | |||||||
|  |  | ||||||
| 	switch action.Type { | 	switch action.Type { | ||||||
| 	case pb.Action_REGISTER: | 	case pb.Action_REGISTER: | ||||||
|  | 		if err := db.CheckRegistrationLimit(ip); err != nil { | ||||||
|  | 			response := &pb.ServerMessage{ | ||||||
|  | 				AuthSuccess:  false, | ||||||
|  | 				ErrorMessage: err.Error(), | ||||||
|  | 			} | ||||||
|  | 			writeMessage(conn, response) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 		playerID, authErr = db.RegisterPlayer(action.Username, action.Password) | 		playerID, authErr = db.RegisterPlayer(action.Username, action.Password) | ||||||
| 	case pb.Action_LOGIN: | 	case pb.Action_LOGIN: | ||||||
| 		playerID, authErr = db.AuthenticatePlayer(action.Username, action.Password) | 		playerID, authErr = db.AuthenticatePlayer(action.Username, action.Password) | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user