// radchat/main.go
// 2025-01-18 22:01:02 -06:00
//
// 628 lines
// 15 KiB
// Go

package main
import (
	"compress/gzip"
	"database/sql"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"net"
	"net/http"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
// Database wraps the *sql.DB handle through which the chat's users and
// messages tables are read and written.
type Database struct {
	db *sql.DB // handle obtained from sql.Open in OpenDatabase
}
// TimezoneToLocation resolves a timezone name (as accepted by
// time.LoadLocation) to a *time.Location. If the name cannot be loaded, a
// fixed zero-offset zone named "UTC" is returned instead of an error.
func TimezoneToLocation(timezone string) *time.Location {
	loc, err := time.LoadLocation(timezone)
	if err == nil {
		return loc
	}
	return time.FixedZone("UTC", 0)
}
func TimeStringToTime(timeString string) time.Time {
t, _ := time.Parse("2006-01-02 15:04:05", timeString)
return t
}
// TimeStringToTimeInLocation reinterprets a "YYYY-MM-DD HH:MM:SS" UTC
// timestamp in the named timezone and returns it in the same text layout.
// Unparseable input and unknown timezones are handled by the helpers
// (zero time and a UTC fallback, respectively).
func TimeStringToTimeInLocation(timeString string, timezone string) string {
	return TimeStringToTime(timeString).
		In(TimezoneToLocation(timezone)).
		Format("2006-01-02 15:04:05")
}
// OpenDatabase returns a Database backed by the "sqlite3" driver for the
// given file path, or nil if sql.Open reports an error. Note that sql.Open
// may only validate its arguments, so a non-nil result does not prove the
// file is reachable.
func OpenDatabase(filepath string) *Database {
	conn, err := sql.Open("sqlite3", filepath)
	if err != nil {
		return nil
	}
	return &Database{db: conn}
}
// Close closes the underlying *sql.DB. Any error from closing is discarded.
func (db *Database) Close() {
	_ = db.db.Close()
}
// DbCreateTableMessages creates the messages table if it does not already
// exist: an autoincrement id, the sender's IP address, the message content,
// and a created_at column defaulting to CURRENT_TIMESTAMP.
// Errors are printed to stdout rather than returned, like the rest of this
// file, instead of being silently dropped.
func (db *Database) DbCreateTableMessages() {
	stmt := `CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip_address TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
	if _, err := db.db.Exec(stmt); err != nil {
		fmt.Println(err)
	}
}
// DbCreateTableUsers creates the users table if it does not already exist.
// Columns, in order: id, ip_address, username (UNIQUE), timezone (default
// 'America/New_York'), created_at. Code that selects from this table relies
// on that column order only when it uses SELECT *.
// Errors are printed to stdout rather than silently dropped.
func (db *Database) DbCreateTableUsers() {
	stmt := `CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip_address TEXT NOT NULL,
username TEXT NOT NULL UNIQUE,
timezone TEXT DEFAULT 'America/New_York',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
	if _, err := db.db.Exec(stmt); err != nil {
		fmt.Println(err)
	}
}
// UserTimezoneSet stores timezone for the user registered from ip_address.
// If no user has that IP, no row changes. Errors are printed, not returned.
func (db *Database) UserTimezoneSet(ip_address, timezone string) {
	// Exec directly instead of Prepare+Exec: a one-shot statement needs no
	// *sql.Stmt to close, and a failed call cannot leave a nil stmt to use.
	if _, err := db.db.Exec("UPDATE users SET timezone = ? WHERE ip_address = ?", timezone, ip_address); err != nil {
		fmt.Println(err)
	}
}
// UserAdd inserts a user row binding username to ip_address. A duplicate
// username violates the table's UNIQUE constraint; that error, like any
// other, is printed rather than returned.
func (db *Database) UserAdd(ip_address, username string) {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("INSERT INTO users (username, ip_address) VALUES (?, ?)", username, ip_address); err != nil {
		fmt.Println(err)
	}
}
// MessageAdd stores a message with the sender's IP address; created_at is
// filled in by the column default. Errors are printed, not returned.
func (db *Database) MessageAdd(ip_address string, content string) {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("INSERT INTO messages (ip_address, content) VALUES (?, ?)", ip_address, content); err != nil {
		fmt.Println(err)
	}
}
// UserNameGet returns the username registered for ip_address, or "" when
// there is none. Query errors other than "no rows" are printed.
func (db *Database) UserNameGet(ip_address string) string {
	var username string
	// QueryRow cannot hand back a nil *Rows on error (the original deferred
	// rows.Close() on a nil Rows would panic) and closes itself after Scan.
	err := db.db.QueryRow("SELECT username FROM users WHERE ip_address = ?", ip_address).Scan(&username)
	if err != nil && err != sql.ErrNoRows {
		fmt.Println(err)
	}
	return username
}
// UserGetTimezone returns the stored timezone name for the user registered
// from ip_address, or "" when there is no such user or the column is NULL.
// Query errors other than "no rows" are printed.
func (db *Database) UserGetTimezone(ip_address string) string {
	// timezone is a nullable column; NullString maps NULL to "" instead of
	// failing the Scan.
	var timezone sql.NullString
	err := db.db.QueryRow("SELECT timezone FROM users WHERE ip_address = ?", ip_address).Scan(&timezone)
	if err != nil && err != sql.ErrNoRows {
		fmt.Println(err)
	}
	return timezone.String
}
// UsersGet returns every registered user. Rows that fail to scan are printed
// and skipped. On a query error it prints the error and returns nil.
func (db *Database) UsersGet() []User {
	// Name the columns explicitly. With SELECT * the table order is
	// (id, ip_address, username, timezone, created_at), and scanning it as
	// if created_at came before timezone put the timestamp into Timezone.
	rows, err := db.db.Query("SELECT id, ip_address, username, timezone FROM users")
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()
	var users []User
	for rows.Next() {
		var (
			u        User
			timezone sql.NullString // column is nullable
		)
		if err := rows.Scan(&u.Id, &u.IpAddress, &u.Username, &timezone); err != nil {
			fmt.Println(err)
			continue
		}
		u.Timezone = timezone.String
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return users
}
// MessagesGet returns all messages in insertion (id) order. Each message has
// the sender's current username (empty if the IP has no registered user) and
// a Timestamp formatted by SQLite as "YYYY-MM-DD HH:MM:SS".
// On a query error it prints the error and returns nil.
func (db *Database) MessagesGet() []Message {
	rows, err := db.db.Query(`
		SELECT messages.id, messages.ip_address, messages.content,
			strftime('%Y-%m-%d %H:%M:%S', messages.created_at) AS created_at,
			users.username
		FROM messages
		LEFT JOIN users ON messages.ip_address = users.ip_address
		ORDER BY messages.id;
	`)
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()
	var messages []Message
	for rows.Next() {
		var (
			m         Message
			createdAt sql.NullString
			// The LEFT JOIN yields NULL when no user owns the sender IP.
			username sql.NullString
		)
		if err := rows.Scan(&m.Id, &m.SenderIp, &m.Content, &createdAt, &username); err != nil {
			fmt.Println(err)
			continue
		}
		m.Timestamp = createdAt.String
		m.SenderUsername = username.String
		messages = append(messages, m)
	}
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return messages
}
func (db *Database) UserNameExists(username string) bool {
rows, err := db.db.Query("SELECT * FROM users WHERE username = ?", username)
if err != nil {
fmt.Println(err)
}
defer rows.Close()
return rows.Next()
}
func (db *Database) UserExists(ip string) bool {
rows, err := db.db.Query("SELECT * FROM users WHERE ip_address = ?", ip)
if err != nil {
fmt.Println(err)
}
defer rows.Close()
return rows.Next()
}
// UserNameChange renames the user registered from ip to newUsername.
// Errors (including a UNIQUE violation) are printed, not returned.
func (db *Database) UserNameChange(ip, newUsername string) {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("UPDATE users SET username = ? WHERE ip_address = ?", newUsername, ip); err != nil {
		fmt.Println(err)
	}
}
// UserMessagesGet returns the messages sent from ip, newest first. Timestamp
// uses the same "YYYY-MM-DD HH:MM:SS" text that MessagesGet produces, so
// both can be passed to TimeStringToTimeInLocation.
// On a query error it prints the error and returns nil.
func (db *Database) UserMessagesGet(ip string) []Message {
	// Let SQLite format the timestamp. The previous code called
	// time.Parse(created_at, created_at) and t.Format(created_at), using the
	// data itself as the layout, which produced garbled timestamps.
	rows, err := db.db.Query(`
		SELECT messages.id, messages.ip_address, messages.content,
			strftime('%Y-%m-%d %H:%M:%S', messages.created_at) AS created_at,
			users.username
		FROM messages
		LEFT JOIN users ON messages.ip_address = users.ip_address
		WHERE messages.ip_address = ?
		ORDER BY messages.created_at DESC, messages.id DESC;
	`, ip)
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()
	var messages []Message
	for rows.Next() {
		var (
			m         Message
			createdAt sql.NullString
			username  sql.NullString // NULL when no user owns this IP
		)
		if err := rows.Scan(&m.Id, &m.SenderIp, &m.Content, &createdAt, &username); err != nil {
			fmt.Println(err)
			continue
		}
		m.Timestamp = createdAt.String
		m.SenderUsername = username.String
		messages = append(messages, m)
	}
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return messages
}
// MessageDelete removes the message with the given id. Errors are printed.
func (db *Database) MessageDelete(id string) {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("DELETE FROM messages WHERE id = ?", id); err != nil {
		fmt.Println(err)
	}
}
// UserDelete removes the user(s) registered from ip. Errors are printed.
func (db *Database) UserDelete(ip string) {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("DELETE FROM users WHERE ip_address = ?", ip); err != nil {
		fmt.Println(err)
	}
}
// UsersDelete removes every row from the users table. Errors are printed.
func (db *Database) UsersDelete() {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("DELETE FROM users"); err != nil {
		fmt.Println(err)
	}
}
// MessagesDelete removes every row from the messages table. Errors are printed.
func (db *Database) MessagesDelete() {
	// Exec directly: no statement to leak, and no nil *sql.Stmt to call if
	// preparation fails.
	if _, err := db.db.Exec("DELETE FROM messages"); err != nil {
		fmt.Println(err)
	}
}
// gzipResponseWriter is an http.ResponseWriter whose body bytes go through a
// gzip stream (the embedded io.Writer). Headers and status codes still go to
// the wrapped ResponseWriter.
type gzipResponseWriter struct {
	http.ResponseWriter
	io.Writer
}

// Write sends data into the gzip stream. It is defined explicitly because
// both embedded fields provide a Write method, which would be ambiguous.
func (g *gzipResponseWriter) Write(data []byte) (int, error) {
	return g.Writer.Write(data)
}

// GzipMiddleware gzip-compresses responses from next when the request's
// Accept-Encoding header mentions "gzip", and passes the request through
// unchanged otherwise.
func GzipMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The body encoding depends on Accept-Encoding, so caches must key on
		// it. Otherwise a gzipped body could be served to a client that
		// cannot decode it, or vice versa.
		w.Header().Add("Vary", "Accept-Encoding")
		if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
			next.ServeHTTP(w, r)
			return
		}
		gz := gzip.NewWriter(w)
		// Close flushes the remaining compressed data and the gzip trailer.
		defer gz.Close()
		w.Header().Set("Content-Encoding", "gzip")
		next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
	})
}
// getClientIP returns the address used to identify the caller.
//
// If an X-Forwarded-For header is present and its first (left-most) entry is
// non-empty after trimming whitespace, that entry is returned. Otherwise the
// host part of r.RemoteAddr is returned, or RemoteAddr verbatim if it has no
// port.
//
// NOTE(review): X-Forwarded-For is client-controlled. Users are keyed by this
// value, so unless the server only runs behind a proxy that overwrites the
// header, any client can claim another user's identity. Confirm the
// deployment before trusting it.
func getClientIP(r *http.Request) string {
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		// Entries are comma-separated and usually written as "a, b", so trim
		// the surrounding space off the first one.
		first := strings.TrimSpace(strings.SplitN(fwd, ",", 2)[0])
		if first != "" {
			return first
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// User is one registered chat participant, as built by Database.UsersGet.
// Every field, including the numeric row id, is carried as a string.
type User struct {
	Id, Username, IpAddress, Timezone string
}
// Message is one chat message as read back from the database: its row id,
// text, the IP it was posted from, the username joined in for that IP, and
// its creation time as text.
type Message struct {
	Id, Content, SenderIp, SenderUsername, Timestamp string
}
// Server holds the chat server's listen address, presence tracking, and
// database handle.
type Server struct {
	// Ip and Port together form the listen address.
	Ip   string
	Port int
	// Connected maps a client IP to the time it was last seen.
	// Access it only while holding mu.
	Connected map[string]time.Time
	// Database is the persistence layer for users and messages.
	Database *Database
	// mu guards Connected and serializes check-then-write database sequences.
	mu sync.Mutex
}
// NewServer builds a Server that will listen on ip:port and store data in the
// SQLite file at dbpath. The returned Server's Database is nil if
// OpenDatabase failed; this constructor does not check that.
func NewServer(ip string, port int, dbpath string) *Server {
	srv := &Server{Ip: ip, Port: port}
	srv.Connected = make(map[string]time.Time)
	srv.Database = OpenDatabase(dbpath)
	// The zero-value sync.Mutex is ready to use; no explicit init needed.
	return srv
}
// AddMessage stores a message from userip by delegating to
// Database.MessageAdd.
func (s *Server) AddMessage(userip string, contents string) {
	store := s.Database
	store.MessageAdd(userip, contents)
}
// updateActivity records now as the last-seen time for ip, under s.mu.
func (s *Server) updateActivity(ip string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	lastSeen := s.Connected
	lastSeen[ip] = time.Now()
}
func (s *Server) cleanupActivity() {
s.mu.Lock()
defer s.mu.Unlock()
for ip, lastActivity := range s.Connected {
if time.Since(lastActivity) > 10*time.Second {
delete(s.Connected, ip)
}
}
}
// handlePing marks the caller as active, prunes idle clients, and replies
// 200 with an empty body.
func (s *Server) handlePing(w http.ResponseWriter, r *http.Request) {
	caller := getClientIP(r)
	s.updateActivity(caller)
	s.cleanupActivity()
	w.WriteHeader(http.StatusOK)
}
// validUsername reports whether username is non-empty and consists only of
// ASCII letters, ASCII digits, and underscores. Length limits are enforced
// by the caller.
func validUsername(username string) bool {
	if username == "" {
		// An empty name would pass the character check vacuously and
		// register a user that shows up as a blank entry.
		return false
	}
	for _, c := range username {
		switch {
		case c >= 'a' && c <= 'z', c >= 'A' && c <= 'Z', c >= '0' && c <= '9', c == '_':
		default:
			return false
		}
	}
	return true
}
// handleUsername registers or renames the caller's username (PUT only).
//
// The JSON body is {"username": "..."}. The name must be at most 64 bytes
// (413 otherwise) and pass validUsername (400 otherwise). A name already held
// by any user gets 409. Otherwise the caller's IP is either renamed or newly
// registered, and {"status": "Username registered"} is returned.
func (s *Server) handleUsername(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPut {
		http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	clientIP := getClientIP(r)
	var req struct {
		Username string `json:"username"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error": "Invalid JSON"}`, http.StatusBadRequest)
		return
	}
	// Input validation touches no shared state, so do it before locking.
	// The message now matches the check, which allows exactly 64.
	if len(req.Username) > 64 {
		http.Error(w, `{"error": "Username too long (must be at most 64 characters)"}`, http.StatusRequestEntityTooLarge)
		return
	}
	if !validUsername(req.Username) {
		http.Error(w, `{"error": "Username must only contain alphanumeric characters and/or underscores"}`, http.StatusBadRequest)
		return
	}
	// Hold mu across the uniqueness check and the write so two requests
	// cannot both claim the same free name. The deferred unlock releases the
	// lock even if a database call panics.
	taken := func() bool {
		s.mu.Lock()
		defer s.mu.Unlock()
		if s.Database.UserNameExists(req.Username) {
			return true
		}
		if s.Database.UserExists(clientIP) {
			s.Database.UserNameChange(clientIP, req.Username)
		} else {
			s.Database.UserAdd(clientIP, req.Username)
		}
		return false
	}()
	if taken {
		http.Error(w, `{"error": "Username already exists"}`, http.StatusConflict)
		return
	}
	json.NewEncoder(w).Encode(map[string]string{"status": "Username registered"})
}
func getMessageTemplate(file string, body string) string {
contents, _ := os.ReadFile(file)
return strings.Replace(string(contents), "{{body}}", body, 1)
}
// handleMessages serves /messages.
//
// GET renders every stored message as an HTML fragment (username, timestamp
// converted to the requesting client's stored timezone, content) and
// substitutes the result into messages.html.
// PUT stores a message from the caller's IP; that IP must already have a
// registered username (401 otherwise). The JSON body is {"message": "..."}.
// Any other method gets 405.
func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		w.Header().Set("Content-Type", "text/html")
		messages := s.Database.MessagesGet()
		// The viewer, and so their timezone, is the same for every message;
		// look it up once rather than once per message.
		timeZone := s.Database.UserGetTimezone(getClientIP(r))
		// A Builder avoids the quadratic cost of repeated string +=.
		var body strings.Builder
		for _, msg := range messages {
			timeLocal := TimeStringToTimeInLocation(msg.Timestamp, timeZone)
			// Message content is arbitrary client input. Escape every field so
			// it renders as text and cannot inject markup or script (stored XSS).
			fmt.Fprintf(&body, `<p><span class="username">%s </span><span class="timestamp">%s</span><br><span class="message">%s</span></p>`,
				html.EscapeString(msg.SenderUsername), html.EscapeString(timeLocal), html.EscapeString(msg.Content))
		}
		w.Write([]byte(getMessageTemplate("messages.html", body.String())))
	case http.MethodPut:
		w.Header().Set("Content-Type", "application/json")
		clientIP := getClientIP(r)
		s.mu.Lock()
		exists := s.Database.UserExists(clientIP)
		username := s.Database.UserNameGet(clientIP)
		s.mu.Unlock()
		if !exists {
			errorFmt := fmt.Sprintf(`{"error": "IP %s not registered with username"}`, clientIP)
			http.Error(w, errorFmt, http.StatusUnauthorized)
			return
		}
		var msg struct {
			Message string `json:"message"`
		}
		if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
			http.Error(w, `{"error": "Invalid JSON"}`, http.StatusBadRequest)
			return
		}
		s.Database.MessageAdd(clientIP, msg.Message)
		json.NewEncoder(w).Encode(map[string]string{
			"status": "Message received",
			"from":   username,
		})
	default:
		http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
	}
}
// handleUsernameStatus (GET only) reports whether the caller's IP has a
// registered username, as {"hasUsername": bool, "username": string}.
func (s *Server) handleUsernameStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	ip := getClientIP(r)

	s.mu.Lock()
	registered := s.Database.UserExists(ip)
	name := s.Database.UserNameGet(ip)
	s.mu.Unlock()

	// The fields are declared in the same (alphabetical) order that
	// encoding/json uses for map keys, so the output bytes match.
	reply := struct {
		HasUsername bool   `json:"hasUsername"`
		Username    string `json:"username"`
	}{HasUsername: registered, Username: name}
	json.NewEncoder(w).Encode(reply)
}
func (s *Server) handleUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
clientIP := getClientIP(r)
s.updateActivity(clientIP)
s.cleanupActivity()
s.mu.Lock()
var users []string
for ip := range s.Connected {
// for all connected, get their usernames
users = append(users, s.Database.UserNameGet(ip))
}
s.mu.Unlock()
json.NewEncoder(w).Encode(map[string]interface{}{
"users": users,
})
}
// handleTimezone (POST only) stores the caller's preferred timezone, used
// when rendering message timestamps. The JSON body is {"timezone": "..."}.
// A name time.LoadLocation cannot resolve gets 400; an unregistered IP gets
// 401; on success the reply is {"status": "success"}.
func (s *Server) handleTimezone(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	clientIP := getClientIP(r)
	var req struct {
		Timezone string `json:"timezone"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error": "Invalid JSON"}`, http.StatusBadRequest)
		return
	}
	// Reject untrusted names the server cannot resolve. Storing them would
	// report success, but the timestamp conversion in this file would then
	// silently fall back to UTC for that user.
	if _, err := time.LoadLocation(req.Timezone); err != nil {
		http.Error(w, `{"error": "Unknown timezone"}`, http.StatusBadRequest)
		return
	}
	s.mu.Lock()
	registered := s.Database.UserExists(clientIP)
	if registered {
		s.Database.UserTimezoneSet(clientIP, req.Timezone)
	}
	s.mu.Unlock()
	if !registered {
		http.Error(w, `{"error": "User not registered"}`, http.StatusUnauthorized)
		return
	}
	json.NewEncoder(w).Encode(map[string]string{
		"status": "success",
	})
}
// handleRoot serves root.html for exactly "/". ServeMux routes every
// otherwise-unmatched path here, so anything else gets 404.
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path == "/" {
		w.Header().Set("Content-Type", "text/html")
		w.Write(readFile("root.html"))
		return
	}
	http.NotFound(w, r)
}
// readFile returns the contents of the file at filepath. Errors are ignored
// as best effort: a missing file yields nil, and a partial read returns
// whatever was read.
func readFile(filepath string) []byte {
	data, err := os.ReadFile(filepath)
	_ = err
	return data
}
// handleJs serves root.js as JavaScript. The request itself is not inspected.
func (s *Server) handleJs(w http.ResponseWriter, r *http.Request) {
	script := readFile("root.js")
	w.Header().Set("Content-Type", "application/javascript")
	w.Write(script)
}
// handleCss serves root.css as a stylesheet. The request itself is not inspected.
func (s *Server) handleCss(w http.ResponseWriter, r *http.Request) {
	sheet := readFile("root.css")
	w.Header().Set("Content-Type", "text/css")
	w.Write(sheet)
}
// Run registers the chat routes, wraps them in GzipMiddleware, and serves
// HTTP on s.Ip:s.Port. It blocks until the listener fails, then prints the
// error and returns.
func (s *Server) Run() {
	handler := http.NewServeMux()
	handler.HandleFunc("/ping", s.handlePing)
	handler.HandleFunc("/username", s.handleUsername)
	handler.HandleFunc("/messages", s.handleMessages)
	handler.HandleFunc("/username/status", s.handleUsernameStatus)
	handler.HandleFunc("/users", s.handleUsers)
	handler.HandleFunc("/", s.handleRoot)
	handler.HandleFunc("/root.js", s.handleJs)
	handler.HandleFunc("/root.css", s.handleCss)
	handler.HandleFunc("/timezone", s.handleTimezone)

	srv := &http.Server{
		// JoinHostPort brackets IPv6 hosts ("[::1]:8080"). Plain "%s:%d"
		// produced an unparseable address for them.
		Addr:    net.JoinHostPort(s.Ip, strconv.Itoa(s.Port)),
		Handler: GzipMiddleware(handler),
		// Bound how long a client may hold a connection, so slow or idle
		// clients cannot exhaust the server (e.g. slowloris).
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	fmt.Printf("Server starting on %s:%d\n", s.Ip, s.Port)
	if err := srv.ListenAndServe(); err != nil {
		fmt.Printf("Server error: %v\n", err)
	}
}
// Stop releases the server's database handle. It does not stop an in-progress
// Run; it only closes the Database.
func (s *Server) Stop() {
	store := s.Database
	store.Close()
}
// main starts the chat server.
//
// Usage: radchat <ip> <port> <database-file>
//
// It validates the arguments, opens the SQLite database, ensures both tables
// exist, and serves until the listener fails.
func main() {
	// Indexing os.Args without this check panicked on missing arguments.
	if len(os.Args) < 4 {
		fmt.Fprintf(os.Stderr, "usage: %s <ip> <port> <database-file>\n", os.Args[0])
		os.Exit(2)
	}
	ip := os.Args[1]
	// An ignored Atoi error previously turned a typo into port 0, which
	// listens on a random port.
	port, err := strconv.Atoi(os.Args[2])
	if err != nil || port < 0 || port > 65535 {
		fmt.Fprintf(os.Stderr, "invalid port %q\n", os.Args[2])
		os.Exit(2)
	}
	databaseFile := os.Args[3]
	server := NewServer(ip, port, databaseFile)
	// OpenDatabase returns nil on failure; every later call would panic.
	if server.Database == nil {
		fmt.Fprintf(os.Stderr, "could not open database %q\n", databaseFile)
		os.Exit(1)
	}
	server.Database.DbCreateTableMessages()
	server.Database.DbCreateTableUsers()
	defer server.Stop()
	server.Run()
}