package db import ( "database/sql" "fmt" _ "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "strconv" "time" ) type User struct { Id string Username string Timezone string HashedPassword string } type Message struct { Id string SenderUsername string Content string Timestamp string Edited bool } type Database struct { db *sql.DB } func OpenDatabase(filepath string) *Database { if db, err := sql.Open("sqlite3", filepath); err != nil { return nil } else { return &Database{ db: db, } } } func (db *Database) Close() { db.db.Close() } func (db *Database) DbCreateTableMessages() { stmt := `CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, content TEXT NOT NULL, edited INTEGER DEFAULT 0, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )` db.db.Exec(stmt) } func (db *Database) DbCreateTableUsers() { stmt := `CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL UNIQUE, hashed_password TEXT NOT NULL, timezone TEXT DEFAULT 'America/New_York', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )` db.db.Exec(stmt) } func (db *Database) UserTimezoneSet(username, timezone string) { _, err := db.db.Exec("UPDATE users SET timezone = ? WHERE username = ?", timezone, username) if err != nil { fmt.Println(err) } } func (db *Database) UserAddWithPassword(username, unhashedPwd string) error { // unhashedPwd can not be larger than 72 bytes if len(unhashedPwd) > 72 { return fmt.Errorf("Password too long") } hashedPwd, err := bcrypt.GenerateFromPassword([]byte(unhashedPwd), bcrypt.DefaultCost) if err != nil { return err } _, err = db.db.Exec("INSERT INTO users (username, hashed_password) VALUES (?, ?)", username, hashedPwd) if err != nil { return err } return nil } func (db *Database) UserPasswordCheck(username, unhashedPwd string) bool { rows, err := db.db.Query("SELECT hashed_password FROM users WHERE username = ?", username) if err != nil { fmt.Println(err) } defer rows.Close() var hashedPwd string rows.Next() rows.Scan(&hashedPwd) err = bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(unhashedPwd)) return err == nil } func (db *Database) MessageAdd(username string, content string) { _, err := db.db.Exec("INSERT INTO messages (username, content) VALUES (?, ?)", username, content) if err != nil { fmt.Println(err) } } func (db *Database) UserGetTimezone(username string) string { rows, err := db.db.Query("SELECT timezone FROM users WHERE username = ?", username) if err != nil { fmt.Println(err) } defer rows.Close() var timezone string rows.Next() rows.Scan(&timezone) return timezone } func (db *Database) UsersGet() []User { rows, err := db.db.Query("SELECT * FROM users") if err != nil { fmt.Println(err) } defer rows.Close() var users []User for rows.Next() { var id string var username string var created_at string var timezone string rows.Scan(&id, &username, &created_at, &timezone) user := User{ Id: id, Username: username, Timezone: timezone, } users = append(users, user) } return users } func (db *Database) MessagesGet() []Message { rows, err := db.db.Query(` SELECT messages.id, messages.username, messages.content, strftime('%Y-%m-%d %H:%M:%S', messages.created_at) as created_at, messages.edited FROM messages `) if err != nil { fmt.Println(err) } defer rows.Close() var messages []Message for rows.Next() { var id string var content string var created_at string var username string var edited int rows.Scan(&id, &username, &content, &created_at, &edited) editedBool := false if edited == 1 { editedBool = true } message := Message{ Id: id, Content: content, SenderUsername: username, Edited: editedBool, Timestamp: created_at, } messages = append(messages, message) } return messages } func (db *Database) UserNameExists(username string) bool { rows, err := db.db.Query("SELECT * FROM users WHERE username = ?", username) if err != nil { fmt.Println(err) } defer rows.Close() return rows.Next() } func (db *Database) UserExists(username string) bool { rows, err := db.db.Query("SELECT * FROM users WHERE username = ?", username) if err != nil { fmt.Println(err) } defer rows.Close() return rows.Next() } func (db *Database) UserNameChange(oldUsername, newUsername string) { _, err := db.db.Exec("UPDATE users SET username = ? WHERE username = ?", newUsername, oldUsername) if err != nil { fmt.Println(err) } } func (db *Database) UserMessagesGet(username string) []Message { rows, err := db.db.Query(` SELECT messages.*, users.username FROM messages LEFT JOIN users ON messages.username = users.username WHERE messages.username = ? ORDER BY messages.created_at DESC; `, username) if err != nil { fmt.Println(err) } defer rows.Close() var messages []Message for rows.Next() { var id string var content string var created_at string var username string var edited int rows.Scan(&id, &content, &created_at, &username, &edited) t, _ := time.Parse(created_at, created_at) editedBool := false if edited == 1 { editedBool = true } message := Message{ Id: id, Content: content, SenderUsername: username, Edited: editedBool, Timestamp: t.Format(created_at), } messages = append(messages, message) } return messages } func (db *Database) MessageDeleteId(id string) { _, err := db.db.Exec("DELETE FROM messages WHERE id = ?", id) if err != nil { fmt.Println(err) } } func (db *Database) MessageDeleteIfOwner(id string, username string) (int, error) { res, err := db.db.Exec("DELETE FROM messages WHERE id = ? AND username = ?", id, username) if err != nil { return 0, err } affected, err := res.RowsAffected() if err != nil { return 0, err } return int(affected), nil } func (db *Database) MessageEditIfOwner(id string, content string, username string) (int, error) { res, err := db.db.Exec("UPDATE messages SET content = ?, edited = 1 WHERE id = ? AND username = ?", content, id, username) if err != nil { return 0, err } affected, err := res.RowsAffected() if err != nil { return 0, err } return int(affected), nil } func (db *Database) DeleteOldMessages(ageMinutes int) { if ageMinutes <= 0 { return } age := strconv.Itoa(ageMinutes) _, err := db.db.Exec("DELETE FROM messages WHERE created_at < datetime('now', ? || ' minutes')", "-"+age) if err != nil { fmt.Println(err) } } func (db *Database) UserDelete(username string) { _, err := db.db.Exec("DELETE FROM users WHERE username = ?", username) if err != nil { fmt.Println(err) } } func (db *Database) UsersDelete() { _, err := db.db.Exec("DELETE FROM users") if err != nil { fmt.Println(err) } } func (db *Database) MessagesDelete() { _, err := db.db.Exec("DELETE FROM messages") if err != nil { fmt.Println(err) } }