diff --git a/config.json b/config.json
index bb9b9f7..069d3bb 100644
--- a/config.json
+++ b/config.json
@@ -4,13 +4,16 @@
"port": 8080
},
"paths": {
- "databasePath": "/home/radon/Documents/chattest.db",
+ "databasePath": "/home/radon/Documents/chat.db",
"indexJsPath": "./public/index.js",
"indexCssPath": "./public/style.css",
- "indexHtmlPath": "./public/index.html"
+ "indexHtmlPath": "./public/index.html",
+ "signupHtmlPath": "./public/signup.html",
+ "loginHtmlPath": "./public/login.html"
},
"options": {
"messageMaxAge": 259200,
- "nameMaxLength": 32
+ "nameMaxLength": 32,
+ "messagesPerPage": 10
}
}
diff --git a/db/db.go b/db/db.go
index c443fbf..3f03678 100644
--- a/db/db.go
+++ b/db/db.go
@@ -4,20 +4,20 @@ import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
+ "golang.org/x/crypto/bcrypt"
"strconv"
"time"
)
type User struct {
- Id string
- Username string
- IpAddress string
- Timezone string
+ Id string
+ Username string
+ Timezone string
+ HashedPassword string
}
type Message struct {
Id string
- SenderIp string
SenderUsername string
Content string
Timestamp string
@@ -45,7 +45,7 @@ func (db *Database) Close() {
func (db *Database) DbCreateTableMessages() {
stmt := `CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
- ip_address TEXT NOT NULL,
+ username TEXT NOT NULL,
content TEXT NOT NULL,
edited INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
@@ -56,61 +56,59 @@ func (db *Database) DbCreateTableMessages() {
func (db *Database) DbCreateTableUsers() {
stmt := `CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
- ip_address TEXT NOT NULL,
username TEXT NOT NULL UNIQUE,
+ hashed_password TEXT NOT NULL,
timezone TEXT DEFAULT 'America/New_York',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
db.db.Exec(stmt)
}
-func (db *Database) UserTimezoneSet(ip_address, timezone string) {
- _, err := db.db.Exec("UPDATE users SET timezone = ? WHERE ip_address = ?", timezone, ip_address)
+func (db *Database) UserTimezoneSet(username, timezone string) {
+ _, err := db.db.Exec("UPDATE users SET timezone = ? WHERE username = ?", timezone, username)
if err != nil {
fmt.Println(err)
}
}
-func (db *Database) UserAdd(ip_address, username string) {
- _, err := db.db.Exec("INSERT INTO users (username, ip_address) VALUES (?, ?)", username, ip_address)
- if err != nil {
- fmt.Println(err)
+func (db *Database) UserAddWithPassword(username, unhashedPwd string) error {
+ // unhashedPwd can not be larger than 72 bytes
+ if len(unhashedPwd) > 72 {
+ return fmt.Errorf("Password too long")
}
+ hashedPwd, err := bcrypt.GenerateFromPassword([]byte(unhashedPwd), bcrypt.DefaultCost)
+ if err != nil {
+ return err
+ }
+ _, err = db.db.Exec("INSERT INTO users (username, hashed_password) VALUES (?, ?)", username, hashedPwd)
+ if err != nil {
+ return err
+ }
+ return nil
}
-func (db *Database) MessageAdd(ip_address string, content string) {
- _, err := db.db.Exec("INSERT INTO messages (ip_address, content) VALUES (?, ?)", ip_address, content)
- if err != nil {
- fmt.Println(err)
- }
-}
-
-func (db *Database) UserNameGet(ip_address string) string {
- rows, err := db.db.Query("SELECT username FROM users WHERE ip_address = ?", ip_address)
+func (db *Database) UserPasswordCheck(username, unhashedPwd string) bool {
+ rows, err := db.db.Query("SELECT hashed_password FROM users WHERE username = ?", username)
if err != nil {
fmt.Println(err)
}
defer rows.Close()
- var username string
+ var hashedPwd string
rows.Next()
- rows.Scan(&username)
- return username
+ rows.Scan(&hashedPwd)
+ err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(unhashedPwd))
+ return err == nil
}
-func (db *Database) UserIpGet(username string) string {
- rows, err := db.db.Query("SELECT ip_address FROM users WHERE username = ?", username)
+func (db *Database) MessageAdd(username string, content string) {
+ _, err := db.db.Exec("INSERT INTO messages (username, content) VALUES (?, ?)", username, content)
if err != nil {
fmt.Println(err)
}
- defer rows.Close()
- var ip_address string
- rows.Next()
- rows.Scan(&ip_address)
- return ip_address
}
-func (db *Database) UserGetTimezone(ip_address string) string {
- rows, err := db.db.Query("SELECT timezone FROM users WHERE ip_address = ?", ip_address)
+func (db *Database) UserGetTimezone(username string) string {
+ rows, err := db.db.Query("SELECT timezone FROM users WHERE username = ?", username)
if err != nil {
fmt.Println(err)
}
@@ -130,16 +128,14 @@ func (db *Database) UsersGet() []User {
var users []User
for rows.Next() {
var id string
- var ip_address string
var username string
var created_at string
var timezone string
- rows.Scan(&id, &ip_address, &username, &created_at, &timezone)
+ rows.Scan(&id, &username, &created_at, &timezone)
user := User{
- Id: id,
- Username: username,
- IpAddress: ip_address,
- Timezone: timezone,
+ Id: id,
+ Username: username,
+ Timezone: timezone,
}
users = append(users, user)
}
@@ -148,11 +144,14 @@ func (db *Database) UsersGet() []User {
func (db *Database) MessagesGet() []Message {
rows, err := db.db.Query(`
- SELECT messages.id, messages.ip_address, messages.content,
- strftime('%Y-%m-%d %H:%M:%S', messages.created_at) as created_at,
- users.username, messages.edited
- FROM messages
- LEFT JOIN users ON messages.ip_address = users.ip_address;
+ SELECT
+ messages.id,
+ messages.username,
+ messages.content,
+ strftime('%Y-%m-%d %H:%M:%S', messages.created_at) as created_at,
+ messages.edited
+ FROM
+ messages
`)
if err != nil {
fmt.Println(err)
@@ -164,11 +163,10 @@ func (db *Database) MessagesGet() []Message {
for rows.Next() {
var id string
var content string
- var ip_address string
var created_at string
var username string
var edited int
- rows.Scan(&id, &ip_address, &content, &created_at, &username, &edited)
+ rows.Scan(&id, &username, &content, &created_at, &edited)
editedBool := false
if edited == 1 {
@@ -178,7 +176,6 @@ func (db *Database) MessagesGet() []Message {
message := Message{
Id: id,
Content: content,
- SenderIp: ip_address,
SenderUsername: username,
Edited: editedBool,
Timestamp: created_at,
@@ -198,8 +195,8 @@ func (db *Database) UserNameExists(username string) bool {
return rows.Next()
}
-func (db *Database) UserExists(ip string) bool {
- rows, err := db.db.Query("SELECT * FROM users WHERE ip_address = ?", ip)
+func (db *Database) UserExists(username string) bool {
+ rows, err := db.db.Query("SELECT * FROM users WHERE username = ?", username)
if err != nil {
fmt.Println(err)
}
@@ -207,28 +204,21 @@ func (db *Database) UserExists(ip string) bool {
return rows.Next()
}
-func (db *Database) UserNameChange(ip, newUsername string) {
- _, err := db.db.Exec("UPDATE users SET username = ? WHERE ip_address = ?", newUsername, ip)
+func (db *Database) UserNameChange(oldUsername, newUsername string) {
+ _, err := db.db.Exec("UPDATE users SET username = ? WHERE username = ?", newUsername, oldUsername)
if err != nil {
fmt.Println(err)
}
}
-func (db *Database) UserMessagesDelete(ip string) {
- _, err := db.db.Exec("DELETE FROM messages WHERE ip_address = ?", ip)
- if err != nil {
- fmt.Println(err)
- }
-}
-
-func (db *Database) UserMessagesGet(ip string) []Message {
+func (db *Database) UserMessagesGet(username string) []Message {
rows, err := db.db.Query(`
SELECT messages.*, users.username
FROM messages
- LEFT JOIN users ON messages.ip_address = users.ip_address
- WHERE messages.ip_address = ?
+ LEFT JOIN users ON messages.username = users.username
+ WHERE messages.username = ?
ORDER BY messages.created_at DESC;
- `, ip)
+ `, username)
if err != nil {
fmt.Println(err)
}
@@ -239,11 +229,10 @@ func (db *Database) UserMessagesGet(ip string) []Message {
for rows.Next() {
var id string
var content string
- var ip_address string
var created_at string
var username string
var edited int
- rows.Scan(&id, &ip_address, &content, &created_at, &username, &edited)
+ rows.Scan(&id, &content, &created_at, &username, &edited)
t, _ := time.Parse(created_at, created_at)
editedBool := false
if edited == 1 {
@@ -252,7 +241,6 @@ func (db *Database) UserMessagesGet(ip string) []Message {
message := Message{
Id: id,
Content: content,
- SenderIp: ip_address,
SenderUsername: username,
Edited: editedBool,
Timestamp: t.Format(created_at),
@@ -268,8 +256,8 @@ func (db *Database) MessageDeleteId(id string) {
fmt.Println(err)
}
}
-func (db *Database) MessageDeleteIfOwner(id string, ip string) (int, error) {
- res, err := db.db.Exec("DELETE FROM messages WHERE id = ? AND ip_address = ?", id, ip)
+func (db *Database) MessageDeleteIfOwner(id string, username string) (int, error) {
+ res, err := db.db.Exec("DELETE FROM messages WHERE id = ? AND username = ?", id, username)
if err != nil {
return 0, err
}
@@ -281,8 +269,8 @@ func (db *Database) MessageDeleteIfOwner(id string, ip string) (int, error) {
}
-func (db *Database) MessageEditIfOwner(id string, content string, ip string) (int, error) {
- res, err := db.db.Exec("UPDATE messages SET content = ?, edited = 1 WHERE id = ? AND ip_address = ?", content, id, ip)
+func (db *Database) MessageEditIfOwner(id string, content string, username string) (int, error) {
+ res, err := db.db.Exec("UPDATE messages SET content = ?, edited = 1 WHERE id = ? AND username = ?", content, id, username)
if err != nil {
return 0, err
}
@@ -304,8 +292,8 @@ func (db *Database) DeleteOldMessages(ageMinutes int) {
}
}
-func (db *Database) UserDeleteIp(ip string) {
- _, err := db.db.Exec("DELETE FROM users WHERE ip_address = ?", ip)
+func (db *Database) UserDelete(username string) {
+ _, err := db.db.Exec("DELETE FROM users WHERE username = ?", username)
if err != nil {
fmt.Println(err)
}
diff --git a/go.mod b/go.mod
index d84ee9d..073c576 100644
--- a/go.mod
+++ b/go.mod
@@ -3,3 +3,5 @@ module chat
go 1.23.4
require github.com/mattn/go-sqlite3 v1.14.24 // direct
+
+require golang.org/x/crypto v0.32.0 // indirect
diff --git a/go.sum b/go.sum
index 9dcdc9b..b996447 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +1,4 @@
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
+golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
+golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
diff --git a/public/index.js b/public/index.js
index 3f9fdaf..bec0ae4 100644
--- a/public/index.js
+++ b/public/index.js
@@ -433,19 +433,33 @@ async function checkUsername() {
const response = await fetch("/username/status");
const data = await response.json();
if (!data.hasUsername) {
- document.getElementById("settings-panel").style
- .display = "block";
- const username = document.getElementById("username");
- username.focus();
- username.selectionStart =
- username.selectionEnd =
- username.value.length;
+ // redirect to login page
+ window.location.href = "/login";
+ //
+ //
+ // document.getElementById("settings-panel").style
+ // .display = "block";
+ // const username = document.getElementById("username");
+ // username.focus();
+ // username.selectionStart =
+ // username.selectionEnd =
+ // username.value.length;
}
} catch (error) {
console.error("Error checking username status:", error);
}
}
+async function getCurrentUsername() {
+ try {
+ const response = await fetch("/username/status");
+ const data = await response.json();
+ return data.username;
+ } catch (error) {
+ console.error("Error getting username:", error);
+ }
+}
+
async function updateCurrentUser() {
try {
const response = await fetch("/username/status");
@@ -508,15 +522,15 @@ async function sendMessage() {
if (!message) {
return;
}
-
try {
lastMessage = message;
+ const username = await getCurrentUsername();
const response = await fetch("/messages", {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
- body: JSON.stringify({ message: message }),
+ body: JSON.stringify({ username: username, message: message }),
});
const data = await response.json();
diff --git a/public/login.html b/public/login.html
new file mode 100644
index 0000000..b22fdc6
--- /dev/null
+++ b/public/login.html
@@ -0,0 +1,178 @@
+
+
+
+
+
+
+
+
+
+
+ RadChat Login
+
+
+
+
+
+
+
+
+
+
+
diff --git a/public/signup.html b/public/signup.html
new file mode 100644
index 0000000..672fe2f
--- /dev/null
+++ b/public/signup.html
@@ -0,0 +1,186 @@
+
+
+
+
+
+
+
+
+
+
+ RadChat Signup
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/readme.md b/readme.md
index 44ba47e..ee23d8f 100644
--- a/readme.md
+++ b/readme.md
@@ -14,6 +14,13 @@
- Lazy load with pagination (frontend and backend)
- Add live voice chat? (This will be fun, maybe a separate app)
### Mid Priority
-- Add actual logging instead of ip based usernames, have messages tied to the logged in user not an ip (db changes)
+- NEW LOGIN STUFF
+ - IN PROGRESS: Add actual logging instead of ip based usernames, have messages tied to the logged in user not an ip (db changes)
+ * Fix editing messages
+ * Fix deleting messages
+ * Fix changing username
+ * Fix CSS for signin page
+ * Add logout button to settings, should touch go, js all that, logout request
+ * Fix CSS for login page
### Low Priority
- Nothing yet
diff --git a/srv/handle.go b/srv/handle.go
index 94daa9e..b912589 100644
--- a/srv/handle.go
+++ b/srv/handle.go
@@ -4,9 +4,21 @@ import (
tu "chat/tu"
"encoding/json"
"fmt"
+ "math/rand"
"net/http"
)
+func generateName() string {
+ adjectives := []string{"Unrelenting", "Mystical", "Radiant", "Curious", "Peaceful", "Ancient", "Wandering", "Silent", "Celestial", "Dancing", "Eternal", "Resolute", "Whispering", "Serene", "Wild"}
+ colors := []string{"Purple", "Azure", "Crimson", "Golden", "Emerald", "Sapphire", "Obsidian", "Silver", "Amber", "Jade", "Indigo", "Violet", "Cerulean", "Copper", "Pearl"}
+ nouns := []string{"Elephant", "Phoenix", "Dragon", "Warrior", "Spirit", "Tiger", "Raven", "Mountain", "River", "Storm", "Falcon", "Wolf", "Ocean", "Star", "Moon"}
+
+ return fmt.Sprintf("%s-%s-%s",
+ adjectives[rand.Intn(len(adjectives))],
+ colors[rand.Intn(len(colors))],
+ nouns[rand.Intn(len(nouns))])
+}
+
func (s *Server) handlePing(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
s.updateActivity(clientIP)
@@ -33,35 +45,32 @@ func (s *Server) handleUsername(w http.ResponseWriter, r *http.Request) {
return
}
- s.mu.Lock()
-
if len(req.Username) > s.Config.Options.NameMaxLength {
http.Error(w, fmt.Sprintf(`{"error": "Username too long (%v out of %v characters maximum)"}`, len(req.Username), s.Config.Options.NameMaxLength), http.StatusRequestEntityTooLarge)
- s.mu.Unlock()
return
}
if !validUsername(req.Username) {
http.Error(w, fmt.Sprintf(`{"error": "Username must only contain alphanumeric characters and/or underscores"}`), http.StatusBadRequest)
- s.mu.Unlock()
return
}
if s.Database.UserNameExists(req.Username) {
http.Error(w, fmt.Sprintf(`{"error": "Username already exists"}`), http.StatusConflict)
- s.mu.Unlock()
return
}
- if s.Database.UserExists(clientIP) {
- s.Database.UserNameChange(clientIP, req.Username)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if username, ok := s.LoggedIn[clientIP]; ok {
+ s.LogUserOut(username)
+ s.Database.UserNameChange(username, req.Username)
+ s.LogUserIn(clientIP, req.Username)
+ json.NewEncoder(w).Encode(map[string]string{"status": "Username changed"})
} else {
- s.Database.UserAdd(clientIP, req.Username)
+ http.Error(w, `{"error": "Failure to change username"}`, http.StatusUnauthorized)
}
- s.mu.Unlock()
-
- json.NewEncoder(w).Encode(map[string]string{"status": "Username registered"})
}
func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
@@ -97,7 +106,11 @@ func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
messages := s.Database.MessagesGet()
for _, msg := range messages {
clientIP := getClientIP(r)
- timeZone := s.Database.UserGetTimezone(clientIP)
+ username, ok := s.LoggedIn[clientIP]
+ timeZone := "UTC"
+ if ok {
+ timeZone = s.Database.UserGetTimezone(username)
+ }
timeLocal := tu.TimeStringToTimeInLocation(msg.Timestamp, timeZone)
edited := ""
if msg.Edited {
@@ -111,22 +124,9 @@ func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
case http.MethodPut:
w.Header().Set("Content-Type", "application/json")
- // Get client's IP
- clientIP := getClientIP(r)
-
- s.mu.Lock()
- exists := s.Database.UserExists(clientIP)
- username := s.Database.UserNameGet(clientIP)
- s.mu.Unlock()
-
- if !exists {
- errorFmt := fmt.Sprintf(`{"error": "IP %s not registered with username"}`, clientIP)
- http.Error(w, errorFmt, http.StatusUnauthorized)
- return
- }
-
var msg struct {
- Message string `json:"message"`
+ Username string `json:"username"`
+ Message string `json:"message"`
}
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
@@ -134,12 +134,19 @@ func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
return
}
- s.Database.MessageAdd(clientIP, msg.Message)
+ clientIP := getClientIP(r)
+ if username, ok := s.LoggedIn[clientIP]; ok {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.Database.MessageAdd(username, msg.Message)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "Message received",
+ "from": username,
+ })
+ } else {
+ http.Error(w, `{"error": "Unauthorized"}`, http.StatusUnauthorized)
+ }
- json.NewEncoder(w).Encode(map[string]string{
- "status": "Message received",
- "from": username,
- })
case http.MethodDelete:
w.Header().Set("Content-Type", "application/json")
@@ -171,21 +178,16 @@ func (s *Server) handleMessages(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleUsernameStatus(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
- return
- }
+ // if r.Method != http.MethodGet {
+ // http.Error(w, `{"error": "Method not allowed"}`, http.StatusMethodNotAllowed)
+ // return
+ // }
+ clientIP := getClientIP(r)
+ username, ok := s.LoggedIn[clientIP]
w.Header().Set("Content-Type", "application/json")
- clientIP := getClientIP(r)
-
- s.mu.Lock()
- exists := s.Database.UserExists(clientIP)
- username := s.Database.UserNameGet(clientIP)
- s.mu.Unlock()
-
json.NewEncoder(w).Encode(map[string]interface{}{
- "hasUsername": exists,
+ "hasUsername": ok,
"username": username,
})
}
@@ -202,12 +204,14 @@ func (s *Server) handleUsers(w http.ResponseWriter, r *http.Request) {
s.updateActivity(clientIP)
s.cleanupActivity()
s.mu.Lock()
+ defer s.mu.Unlock()
var users []string
for ip := range s.Connected {
// for all connected, get their usernames
- users = append(users, s.Database.UserNameGet(ip))
+ if username, ok := s.LoggedIn[ip]; ok {
+ users = append(users, username)
+ }
}
- s.mu.Unlock()
json.NewEncoder(w).Encode(map[string]interface{}{
"users": users,
@@ -233,18 +237,15 @@ func (s *Server) handleTimezone(w http.ResponseWriter, r *http.Request) {
}
s.mu.Lock()
-
- if !s.Database.UserExists(clientIP) {
+ defer s.mu.Unlock()
+ if username, ok := s.LoggedIn[clientIP]; ok {
+ s.Database.UserTimezoneSet(username, req.Timezone)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "success",
+ })
+ } else {
http.Error(w, `{"error": "User not registered"}`, http.StatusUnauthorized)
- s.mu.Unlock()
- return
}
-
- s.Database.UserTimezoneSet(clientIP, req.Timezone)
- s.mu.Unlock()
- json.NewEncoder(w).Encode(map[string]string{
- "status": "success",
- })
}
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
@@ -283,3 +284,90 @@ func (s *Server) handleMessagesLength(w http.ResponseWriter, r *http.Request) {
"length": len(messages),
})
}
+
+func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ w.Header().Set("Content-Type", "text/html")
+ file := readFile(s.Config.Paths.LoginHtmlPath)
+ w.Write(file)
+ return
+ case http.MethodPost:
+ w.Header().Set("Content-Type", "application/json")
+ var req struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, `{"error": "Invalid JSON"}`, http.StatusBadRequest)
+ return
+ }
+
+ validLogin := s.Database.UserPasswordCheck(req.Username, req.Password)
+ if !validLogin {
+ http.Error(w, `{"error": "Invalid username or password"}`, http.StatusUnauthorized)
+ return
+ } else {
+ clientIP := getClientIP(r)
+ s.LogUserIn(clientIP, req.Username)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "Logged in",
+ })
+ }
+ }
+}
+
+func (s *Server) handleSignup(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ w.Header().Set("Content-Type", "text/html")
+ file := readFile(s.Config.Paths.SignupHtmlPath)
+ w.Write(file)
+ return
+ case http.MethodPost:
+ w.Header().Set("Content-Type", "application/json")
+ var req struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+ }
+
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, `{"error": "Invalid JSON"}`, http.StatusBadRequest)
+ return
+ }
+
+ // validate username length
+ if len(req.Username) > s.Config.Options.NameMaxLength {
+ http.Error(w, fmt.Sprintf(`{"error": "Username too long (%v out of %v characters maximum)"}`, len(req.Username), s.Config.Options.NameMaxLength), http.StatusRequestEntityTooLarge)
+ return
+ }
+
+ // validate username
+ if !validUsername(req.Username) {
+ http.Error(w, fmt.Sprintf(`{"error": "Username must only contain alphanumeric characters and/or underscores"}`), http.StatusBadRequest)
+ return
+ }
+
+ // validate user doesnt already exist
+ if s.Database.UserNameExists(req.Username) {
+ http.Error(w, fmt.Sprintf(`{"error": "Username already exists"}`), http.StatusConflict)
+ return
+ }
+
+ // add user to database with hashedpassword
+ err := s.Database.UserAddWithPassword(req.Username, req.Password)
+ if err != nil {
+ fmt.Println("Database error while signing up a new user")
+ http.Error(w, fmt.Sprintf(`{"error": "%v"}`, err), http.StatusInternalServerError)
+ return
+ }
+
+ // log user in
+ clientIP := getClientIP(r)
+ s.LogUserIn(clientIP, req.Username)
+
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "User created",
+ })
+ }
+}
diff --git a/srv/srv.go b/srv/srv.go
index 6226f98..437fd51 100644
--- a/srv/srv.go
+++ b/srv/srv.go
@@ -21,15 +21,17 @@ type Config struct {
Port int `json:"port"`
} `json:"server"`
Paths struct {
- DatabasePath string `json:"databasePath"`
- IndexJsPath string `json:"indexJsPath"`
- IndexCssPath string `json:"indexCssPath"`
- IndexHtmlPath string `json:"indexHtmlPath"`
- MessagesHtmlPath string `json:"messagesHtmlPath"`
+ DatabasePath string `json:"databasePath"`
+ IndexJsPath string `json:"indexJsPath"`
+ IndexCssPath string `json:"indexCssPath"`
+ IndexHtmlPath string `json:"indexHtmlPath"`
+ SignupHtmlPath string `json:"signupHtmlPath"`
+ LoginHtmlPath string `json:"loginHtmlPath"`
} `json:"paths"`
Options struct {
- MessageMaxAge int `json:"messageMaxAge"`
- NameMaxLength int `json:"nameMaxLength"`
+ MessageMaxAge int `json:"messageMaxAge"`
+ NameMaxLength int `json:"nameMaxLength"`
+ MessagePerPage int `json:"messagePerPage"`
} `json:"options"`
}
@@ -40,7 +42,8 @@ func LoadConfig(filepath string) Config {
config.Paths.IndexHtmlPath = pathMaker(config.Paths.IndexHtmlPath)
config.Paths.IndexJsPath = pathMaker(config.Paths.IndexJsPath)
config.Paths.IndexCssPath = pathMaker(config.Paths.IndexCssPath)
- config.Paths.MessagesHtmlPath = pathMaker(config.Paths.MessagesHtmlPath)
+ config.Paths.SignupHtmlPath = pathMaker(config.Paths.SignupHtmlPath)
+ config.Paths.LoginHtmlPath = pathMaker(config.Paths.LoginHtmlPath)
config.Paths.DatabasePath = pathMaker(config.Paths.DatabasePath)
if err != nil {
fmt.Println("Error parsing config file: ", err)
@@ -118,6 +121,7 @@ func validUsername(username string) bool {
}
type Server struct {
+ LoggedIn map[string]string // Map Username -> IP
Connected map[string]time.Time // Map IP -> Last activity time
Database *db.Database
Config Config
@@ -127,6 +131,7 @@ type Server struct {
func NewServer(config Config) *Server {
return &Server{
+ LoggedIn: make(map[string]string),
Connected: make(map[string]time.Time),
Database: db.OpenDatabase(config.Paths.DatabasePath),
Config: config,
@@ -138,6 +143,23 @@ func (s *Server) AddMessage(userip string, contents string) {
s.Database.MessageAdd(userip, contents)
}
+func (s *Server) LogUserIn(ip, username string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.LoggedIn[ip] = username
+}
+
+func (s *Server) LogUserOut(username string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for ip, u := range s.LoggedIn {
+ if u == username {
+ delete(s.LoggedIn, ip)
+ }
+ }
+ delete(s.LoggedIn, username)
+}
+
func (s *Server) updateActivity(ip string) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -150,6 +172,7 @@ func (s *Server) cleanupActivity() {
for ip, lastActivity := range s.Connected {
if time.Since(lastActivity) > 10*time.Second {
delete(s.Connected, ip)
+ s.LogUserOut(s.LoggedIn[ip])
}
}
}
@@ -169,6 +192,8 @@ func (s *Server) Run() {
handler.HandleFunc("/users", s.handleUsers)
handler.HandleFunc("/username", s.handleUsername)
handler.HandleFunc("/messages", s.handleMessages)
+ handler.HandleFunc("/signup", s.handleSignup)
+ handler.HandleFunc("/login", s.handleLogin)
fmt.Printf("Server starting on %s:%d\n", s.Config.Server.IpAddress, s.Config.Server.Port)
defer s.Stop()
if err := http.ListenAndServe(fmt.Sprintf("%s:%d", s.Config.Server.IpAddress, s.Config.Server.Port), GzipMiddleware(handler)); err != nil {