add scheme check and path escaping for urls

This commit is contained in:
Radon 2025-09-04 19:42:34 -05:00
parent 6618425c3d
commit 9d23223240
2 changed files with 45 additions and 9 deletions

28
hub.go
View File

@ -1,11 +1,21 @@
package main package main
// HTTP handlers and hub run loop integration for RadChat.
//
// This file contains:
// - WebSocket and HTTP endpoints (username check, user count, file upload/download)
// - Helpers to send messages to specific clients and broadcast user lists
// - The hub Run loop that manages registration, unregistration, and broadcasting
//
// The actual Hub/Client definitions live in the server package. Here we glue
// them to HTTP.
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"radchat/server" "radchat/server"
@ -13,6 +23,7 @@ import (
"time" "time"
) )
// IsUsernameTaken returns true if the given username already exists among connected clients (case-insensitive).
func IsUsernameTaken(h *server.Hub, username string) bool { func IsUsernameTaken(h *server.Hub, username string) bool {
h.Mutex.RLock() h.Mutex.RLock()
defer h.Mutex.RUnlock() defer h.Mutex.RUnlock()
@ -25,6 +36,7 @@ func IsUsernameTaken(h *server.Hub, username string) bool {
return false return false
} }
// HandleUserCountCheck responds with the current number of connected users.
func HandleUserCountCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) { func HandleUserCountCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" { if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -41,6 +53,7 @@ func HandleUserCountCheck(hub *server.Hub, w http.ResponseWriter, r *http.Reques
} }
} }
// HandleUsernameCheck validates a requested username and checks for collisions.
func HandleUsernameCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) { func HandleUsernameCheck(hub *server.Hub, w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" { if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@ -171,7 +184,7 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration,
return return
} }
err := r.ParseMultipartForm(32 << 20) err := r.ParseMultipartForm(256 << 20)
if err != nil { if err != nil {
http.Error(w, "Error parsing form", http.StatusBadRequest) http.Error(w, "Error parsing form", http.StatusBadRequest)
return return
@ -199,21 +212,15 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration,
continue continue
} }
fmt.Println(fileHeader.Filename)
filename := SanitizeFilename(fileHeader.Filename) filename := SanitizeFilename(fileHeader.Filename)
fmt.Println(filename)
// get filename before the extension // get filename before the extension
fileExt := filepath.Ext(filename) fileExt := filepath.Ext(filename)
fmt.Println(fileExt)
fileName := strings.TrimSuffix(filename, fileExt) fileName := strings.TrimSuffix(filename, fileExt)
fmt.Println(fileName)
fileFullName := fileName + "_" + time.Now().Format("20060102150405") + fileExt fileFullName := fileName + "_" + time.Now().Format("20060102150405") + fileExt
fmt.Println(fileFullName)
filePath := filepath.Join(filesDir, fileFullName) filePath := filepath.Join(filesDir, fileFullName)
fmt.Println(filePath)
err = os.WriteFile(filePath, fileBytes, 0644) err = os.WriteFile(filePath, fileBytes, 0644)
if err != nil { if err != nil {
@ -223,8 +230,12 @@ func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration,
uploadedFiles = append(uploadedFiles, fileFullName) uploadedFiles = append(uploadedFiles, fileFullName)
fileFullName = url.PathEscape(fileFullName)
fileLocation := filepath.Join(r.URL.Path, fileFullName) fileLocation := filepath.Join(r.URL.Path, fileFullName)
fileLink := "https://" + filepath.Join(r.Host, fileLocation)
scheme := GetScheme(r)
fileLink := scheme + "://" + filepath.Join(r.Host, fileLocation)
// find our client and send a message as them // find our client and send a message as them
clients := h.Clients clients := h.Clients
@ -294,7 +305,6 @@ func HandleFileDownload(filesDir string, w http.ResponseWriter, r *http.Request)
} }
filePath := filepath.Join(filesDir, filename) filePath := filepath.Join(filesDir, filename)
fmt.Println(filePath)
absUploadsDir, err := filepath.Abs(filesDir) absUploadsDir, err := filepath.Abs(filesDir)
if err != nil { if err != nil {

View File

@ -1,15 +1,19 @@
package main package main
// Utilities for filesystem, IDs, HTTP scheme detection, and input sanitization.
// These helpers are used by the HTTP handlers and startup code.
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
) )
// MakeDir creates the directory if it does not already exist.
func MakeDir(dirname string) error { func MakeDir(dirname string) error {
if _, err := os.Stat(dirname); os.IsNotExist(err) { if _, err := os.Stat(dirname); os.IsNotExist(err) {
return os.Mkdir(dirname, 0666) return os.Mkdir(dirname, 0666)
@ -17,6 +21,7 @@ func MakeDir(dirname string) error {
return nil return nil
} }
// DeleteDirContents removes all files and directories inside dirname but not the directory itself.
func DeleteDirContents(dirname string) error { func DeleteDirContents(dirname string) error {
entries, err := os.ReadDir(dirname) entries, err := os.ReadDir(dirname)
if err != nil { if err != nil {
@ -38,6 +43,7 @@ func DeleteDirContents(dirname string) error {
return nil return nil
} }
// IsUsernameValid validates length, allowed characters, and rejects common placeholder names.
func IsUsernameValid(username string) bool { func IsUsernameValid(username string) bool {
minLength := 3 minLength := 3
maxLength := 24 maxLength := 24
@ -68,6 +74,7 @@ func IsUsernameValid(username string) bool {
return true return true
} }
// GenerateId creates a short pseudo-unique ID derived from time and a hash.
func GenerateId() string { func GenerateId() string {
timestamp := time.Now().UnixNano() timestamp := time.Now().UnixNano()
randomComponent := time.Now().UnixNano() % 1000000 // Add some randomness randomComponent := time.Now().UnixNano() % 1000000 // Add some randomness
@ -76,6 +83,25 @@ func GenerateId() string {
return hex.EncodeToString(hash[:])[:16] return hex.EncodeToString(hash[:])[:16]
} }
// GetScheme determines the request scheme, respecting TLS and X-Forwarded-Proto.
func GetScheme(r *http.Request) string {
if r.URL.Scheme != "" {
return r.URL.Scheme
}
if r.TLS != nil {
return "https"
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto != "" {
return proto
}
return "http"
}
// SanitizeFilename removes path traversal characters and trims whitespace.
func SanitizeFilename(filename string) string { func SanitizeFilename(filename string) string {
filename = strings.ReplaceAll(filename, "/", "") filename = strings.ReplaceAll(filename, "/", "")
filename = strings.ReplaceAll(filename, "\\", "") filename = strings.ReplaceAll(filename, "\\", "")