diff --git a/hub.go b/hub.go index 2043185..b7a289d 100644 --- a/hub.go +++ b/hub.go @@ -1,11 +1,21 @@ package main +// HTTP handlers and hub run loop integration for RadChat. +// +// This file contains: +// - WebSocket and HTTP endpoints (username check, user count, file upload/download) +// - Helpers to send messages to specific clients and broadcast user lists +// - The hub Run loop that manages registration, unregistration, and broadcasting +// +// The actual Hub/Client definitions live in the server package. Here we glue +// them to HTTP. import ( "encoding/json" "fmt" "io" "log" "net/http" + "net/url" "os" "path/filepath" "radchat/server" @@ -13,6 +23,7 @@ import ( "time" ) +// IsUsernameTaken returns true if the given username already exists among connected clients (case-insensitive). func IsUsernameTaken(h *server.Hub, username string) bool { h.Mutex.RLock() defer h.Mutex.RUnlock() @@ -25,6 +36,7 @@ func IsUsernameTaken(h *server.Hub, username string) bool { return false } +// HandleUserCountCheck responds with the current number of connected users. func HandleUserCountCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -41,6 +53,7 @@ func HandleUserCountCheck(hub *server.Hub, w http.ResponseWriter, r *http.Reques } } +// HandleUsernameCheck validates a requested username and checks for collisions. func HandleUsernameCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -171,7 +184,7 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration, return } - err := r.ParseMultipartForm(32 << 20) + err := r.ParseMultipartForm(256 << 20) if err != nil { http.Error(w, "Error parsing form", http.StatusBadRequest) return @@ -199,21 +212,15 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration, continue } - fmt.Println(fileHeader.Filename) filename := SanitizeFilename(fileHeader.Filename) - fmt.Println(filename) // get filename before the extension fileExt := filepath.Ext(filename) - fmt.Println(fileExt) fileName := strings.TrimSuffix(filename, fileExt) - fmt.Println(fileName) fileFullName := fileName + "_" + time.Now().Format("20060102150405") + fileExt - fmt.Println(fileFullName) filePath := filepath.Join(filesDir, fileFullName) - fmt.Println(filePath) err = os.WriteFile(filePath, fileBytes, 0644) if err != nil { @@ -223,8 +230,12 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration, uploadedFiles = append(uploadedFiles, fileFullName) + fileFullName = url.PathEscape(fileFullName) fileLocation := filepath.Join(r.URL.Path, fileFullName) - fileLink := "https://" + filepath.Join(r.Host, fileLocation) + + scheme := GetScheme(r) + + fileLink := scheme + "://" + filepath.Join(r.Host, fileLocation) // find our client and send a message as them clients := h.Clients @@ -294,7 +305,6 @@ func HandleFileDownload(filesDir string, w http.ResponseWriter, r *http.Request) } filePath := filepath.Join(filesDir, filename) - fmt.Println(filePath) absUploadsDir, err := filepath.Abs(filesDir) if err != nil { diff --git a/utils.go b/utils.go index a422f78..3f00a2d 100644 --- a/utils.go +++ b/utils.go @@ -1,15 +1,19 @@ package main +// Utilities for filesystem, IDs, HTTP scheme detection, and input sanitization. +// These helpers are used by the HTTP handlers and startup code. import ( "crypto/sha256" "encoding/hex" "fmt" + "net/http" "os" "path/filepath" "strings" "time" ) +// MakeDir creates the directory if it does not already exist. func MakeDir(dirname string) error { if _, err := os.Stat(dirname); os.IsNotExist(err) { return os.Mkdir(dirname, 0666) @@ -17,6 +21,7 @@ func MakeDir(dirname string) error { return nil } +// DeleteDirContents removes all files and directories inside dirname but not the directory itself. func DeleteDirContents(dirname string) error { entries, err := os.ReadDir(dirname) if err != nil { @@ -38,6 +43,7 @@ func DeleteDirContents(dirname string) error { return nil } +// IsUsernameValid validates length, allowed characters, and rejects common placeholder names. func IsUsernameValid(username string) bool { minLength := 3 maxLength := 24 @@ -68,6 +74,7 @@ func IsUsernameValid(username string) bool { return true } +// GenerateId creates a short pseudo-unique ID derived from time and a hash. func GenerateId() string { timestamp := time.Now().UnixNano() randomComponent := time.Now().UnixNano() % 1000000 // Add some randomness @@ -76,6 +83,25 @@ func GenerateId() string { return hex.EncodeToString(hash[:])[:16] } +// GetScheme determines the request scheme, respecting TLS and X-Forwarded-Proto. +func GetScheme(r *http.Request) string { + if r.URL.Scheme != "" { + return r.URL.Scheme + } + + if r.TLS != nil { + return "https" + } + + proto := r.Header.Get("X-Forwarded-Proto") + if proto != "" { + return proto + } + + return "http" +} + +// SanitizeFilename removes path traversal characters and trims whitespace. func SanitizeFilename(filename string) string { filename = strings.ReplaceAll(filename, "/", "") filename = strings.ReplaceAll(filename, "\\", "")