add file upload/download

This commit is contained in:
Radon 2025-09-03 20:41:17 -05:00
parent c9c5fff712
commit 22c3996527
3 changed files with 163 additions and 62 deletions

210
hub.go
View File

@ -3,12 +3,14 @@ package main
import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"radchat/server"
"strings"
"time"
)
func IsUsernameTaken(h *server.Hub, username string) bool {
@ -147,64 +149,160 @@ func HandleWebSocket(hub *server.Hub, w http.ResponseWriter, r *http.Request) {
go ReadPump(client)
}
func HandleUserFiles(h *server.Hub, uploadsDir string, w http.ResponseWriter, r *http.Request) {
// get will essentially serve the uploads folder as a static folder so someone can get something from it
// put/post will allow you to upload files to the uploads folder
_ = h
switch r.Method {
case "POST":
fallthrough
case "PUT":
// TODO: Implement uploading!
// upload files to the uploads folder
case "GET":
path := strings.TrimPrefix(r.URL.Path, "/")
segments := strings.Split(path, "/")
if len(segments) < 2 {
http.Error(w, "Filename required", http.StatusBadRequest)
return
}
filename := segments[len(segments)-1]
filename = SanitizeFilename(filename)
if filename == "" {
http.Error(w, "Invalid filename", http.StatusBadRequest)
return
}
filePath := filepath.Join(uploadsDir, filename)
fmt.Println(filePath)
absUploadsDir, err := filepath.Abs(uploadsDir)
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
if !strings.HasPrefix(absFilePath, absUploadsDir) {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
if _, err := os.Stat(filePath); os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
return
}
http.ServeFile(w, r, filePath)
func HandleFileUpload(h *server.Hub, filesDir string, fileTimeout time.Duration, w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Parse multipart form (50MB max memory)
err := r.ParseMultipartForm(100 << 20)
if err != nil {
http.Error(w, "Error parsing form", http.StatusBadRequest)
return
}
// Get all files
files := r.MultipartForm.File
_ = r.ParseForm()
username := r.FormValue("username")
clientId := r.FormValue("client_id")
var uploadedFiles []string
for _, fileHeaders := range files {
for _, fileHeader := range fileHeaders {
file, err := fileHeader.Open()
if err != nil {
continue
}
fileBytes, err := io.ReadAll(file)
if err != nil {
_ = file.Close()
continue
}
fmt.Println(fileHeader.Filename)
filename := SanitizeFilename(fileHeader.Filename)
fmt.Println(filename)
// get filename before the extension
fileExt := filepath.Ext(filename)
fmt.Println(fileExt)
fileName := strings.TrimSuffix(filename, fileExt)
fmt.Println(fileName)
fileFullName := fileName + "_" + time.Now().Format("20060102150405") + fileExt
fmt.Println(fileFullName)
filePath := filepath.Join(filesDir, fileFullName)
fmt.Println(filePath)
err = os.WriteFile(filePath, fileBytes, 0644)
if err != nil {
_ = file.Close()
continue
}
uploadedFiles = append(uploadedFiles, fileFullName)
fileLocation := filepath.Join(r.URL.Path, fileFullName)
fileLink := "https://" + filepath.Join(r.Host, fileLocation)
// find our client and send a message as them
clients := h.Clients
for client := range clients {
if client.Id == clientId {
broadcastMsg := Message{
Type: "chat_message",
Username: username,
Data: map[string]any{
"message": fileLink,
"timestamp": time.Now().UnixMilli(),
},
}
data, _ := json.Marshal(broadcastMsg)
h.Broadcast <- data
break
}
}
go func() {
time.Sleep(fileTimeout)
err := os.Remove(filename)
if err != nil {
return
}
}()
_ = file.Close()
}
}
// Respond with success
response := map[string]interface{}{
"success": true,
"uploaded_files": uploadedFiles,
"count": len(uploadedFiles),
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(response)
if err != nil {
return
}
}
func HandleFileDownload(filesDir string, w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
path := strings.TrimPrefix(r.URL.Path, "/")
segments := strings.Split(path, "/")
if len(segments) < 2 {
http.Error(w, "Filename required", http.StatusBadRequest)
return
}
filename := segments[len(segments)-1]
filename = SanitizeFilename(filename)
if filename == "" {
http.Error(w, "Invalid filename", http.StatusBadRequest)
return
}
filePath := filepath.Join(filesDir, filename)
fmt.Println(filePath)
absUploadsDir, err := filepath.Abs(filesDir)
if err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
absFilePath, err := filepath.Abs(filePath)
if err != nil {
http.Error(w, "Invalid file path", http.StatusBadRequest)
return
}
if !strings.HasPrefix(absFilePath, absUploadsDir) {
http.Error(w, "Access denied", http.StatusForbidden)
return
}
if _, err := os.Stat(filePath); os.IsNotExist(err) {
http.Error(w, "File not found", http.StatusNotFound)
return
}
http.ServeFile(w, r, filePath)
}
func Run(h *server.Hub) {

12
main.go
View File

@ -7,6 +7,7 @@ import (
"net/http"
"path/filepath"
"radchat/server"
"time"
)
func main() {
@ -16,6 +17,7 @@ func main() {
var gzipEnabled = flag.Bool("gzip-enable", false, "Enable gzip compression")
var cachingDisabled = flag.Bool("cache-disable", false, "Disable caching")
var filesDirectory = flag.String("files", "./files", "Directory to store upload files")
var filesTimeout = flag.Int("files-timeout", 3600, "File timeout in seconds")
var origin = flag.String("origin", "", "Origin to allow (e.g. example.com), leave blank to allow all")
var help = flag.Bool("help", false, "Show help")
@ -26,9 +28,9 @@ func main() {
return
}
filesDir, err := filepath.Abs(*filesDirectory)
filesDir := filepath.Clean(*filesDirectory)
_ = DeleteDirContents(filesDir)
err = MakeDir(filesDir)
err := MakeDir(filesDir)
if err != nil {
log.Fatal(err)
}
@ -40,8 +42,12 @@ func main() {
mux := http.NewServeMux()
mux.HandleFunc("/files", func(w http.ResponseWriter, r *http.Request) {
HandleFileUpload(hub, filesDir, time.Duration(*filesTimeout)*time.Second, w, r)
})
mux.HandleFunc("/files/", func(w http.ResponseWriter, r *http.Request) {
HandleUserFiles(hub, filesDir, w, r)
HandleFileDownload(filesDir, w, r)
})
mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {

View File

@ -82,8 +82,5 @@ func SanitizeFilename(filename string) string {
filename = strings.ReplaceAll(filename, "..", "")
filename = strings.TrimSpace(filename)
if strings.HasPrefix(filename, ".") || filename == "" {
return ""
}
return filename
}