radchat/db/db.go

315 lines
6.9 KiB
Go

package db
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
"strconv"
"time"
)
// User describes one account as handed out by this package.
type User struct {
	// Id is the account's row identifier, carried as text.
	Id string
	// Username is the account's login name.
	Username string
	// Timezone is the timezone name stored for the account.
	Timezone string
	// HashedPassword holds the stored password hash; UsersGet leaves it empty.
	HashedPassword string
}
// Message describes one chat message as handed out by this package.
type Message struct {
	// Id is the message's row identifier, carried as text.
	Id string
	// SenderUsername is the username recorded on the message.
	SenderUsername string
	// Content is the message body.
	Content string
	// Timestamp is the creation time as text.
	Timestamp string
	// Edited reports whether the message's edited flag is set.
	Edited bool
}
// Database wraps the *sql.DB handle that every method in this file runs its
// statements against.
type Database struct {
	// db is the underlying connection handle.
	db *sql.DB
}
// OpenDatabase opens the "sqlite3" database at filepath and wraps the handle
// in a Database.
//
// It returns nil when sql.Open reports an error; the error itself is not
// passed on to the caller.
func OpenDatabase(filepath string) *Database {
	handle, err := sql.Open("sqlite3", filepath)
	if err != nil {
		return nil
	}
	return &Database{db: handle}
}
// Close releases the underlying database handle.
//
// The method has no error return, so a failure reported while closing is
// printed to standard output, the same way the other Database methods report
// errors, instead of being dropped.
func (db *Database) Close() {
	if err := db.db.Close(); err != nil {
		fmt.Println(err)
	}
}
// DbCreateTableMessages creates the messages table when it does not exist yet.
//
// The table has an auto-incrementing id, the sender's username, the content,
// an edited flag defaulting to 0 and a created_at timestamp defaulting to the
// current time. A failure to run the statement is printed to standard output
// so that a missing table does not go unnoticed.
func (db *Database) DbCreateTableMessages() {
	stmt := `CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
content TEXT NOT NULL,
edited INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
	if _, err := db.db.Exec(stmt); err != nil {
		fmt.Println(err)
	}
}
// DbCreateTableUsers creates the users table when it does not exist yet.
//
// The table has an auto-incrementing id, a unique username, the hashed
// password, a timezone defaulting to 'America/New_York' and a created_at
// timestamp defaulting to the current time. A failure to run the statement is
// printed to standard output so that a missing table does not go unnoticed.
func (db *Database) DbCreateTableUsers() {
	stmt := `CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
hashed_password TEXT NOT NULL,
timezone TEXT DEFAULT 'America/New_York',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`
	if _, err := db.db.Exec(stmt); err != nil {
		fmt.Println(err)
	}
}
// UserTimezoneSet stores timezone on the users row whose username matches.
//
// Errors from the update are printed to standard output and otherwise ignored.
func (db *Database) UserTimezoneSet(username, timezone string) {
	const query = "UPDATE users SET timezone = ? WHERE username = ?"
	if _, err := db.db.Exec(query, timezone, username); err != nil {
		fmt.Println(err)
	}
}
// UserAddWithPassword hashes unhashedPwd with bcrypt at the default cost and
// inserts a new users row for username.
//
// It returns an error when the password is longer than 72 bytes, when hashing
// fails, or when the insert fails.
func (db *Database) UserAddWithPassword(username, unhashedPwd string) error {
	// Passwords above this many bytes are rejected before hashing.
	const maxPasswordBytes = 72
	if len(unhashedPwd) > maxPasswordBytes {
		return fmt.Errorf("Password too long")
	}

	digest, err := bcrypt.GenerateFromPassword([]byte(unhashedPwd), bcrypt.DefaultCost)
	if err != nil {
		return err
	}

	_, err = db.db.Exec("INSERT INTO users (username, hashed_password) VALUES (?, ?)", username, digest)
	return err
}
// UserPasswordCheck reports whether unhashedPwd matches the bcrypt hash stored
// for username.
//
// It returns false when the user does not exist, when the lookup fails, or
// when the password does not match. Lookup failures other than "no such row"
// are printed to standard output.
func (db *Database) UserPasswordCheck(username, unhashedPwd string) bool {
	var hashedPwd string
	err := db.db.QueryRow("SELECT hashed_password FROM users WHERE username = ?", username).Scan(&hashedPwd)
	if err != nil {
		// An unknown username is an ordinary negative answer, not a failure
		// worth logging.
		if err != sql.ErrNoRows {
			fmt.Println(err)
		}
		return false
	}
	return bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(unhashedPwd)) == nil
}
// MessageAdd inserts a new messages row with the given username and content.
//
// Errors from the insert are printed to standard output and otherwise ignored.
func (db *Database) MessageAdd(username string, content string) {
	const query = "INSERT INTO messages (username, content) VALUES (?, ?)"
	if _, err := db.db.Exec(query, username, content); err != nil {
		fmt.Println(err)
	}
}
// UserGetTimezone returns the timezone stored for username.
//
// It returns the empty string when the user does not exist or the lookup
// fails. Failures other than "no such row" are printed to standard output.
func (db *Database) UserGetTimezone(username string) string {
	var timezone string
	err := db.db.QueryRow("SELECT timezone FROM users WHERE username = ?", username).Scan(&timezone)
	if err != nil {
		if err != sql.ErrNoRows {
			fmt.Println(err)
		}
		return ""
	}
	return timezone
}
// UsersGet returns every row of the users table as a User with Id, Username
// and Timezone filled in. HashedPassword is deliberately left empty.
//
// It returns nil when the query fails. Rows that cannot be read are skipped.
// All errors are printed to standard output.
func (db *Database) UsersGet() []User {
	// Name the columns explicitly so the Scan destinations below always line
	// up with the result set, whatever the table's column order is.
	rows, err := db.db.Query("SELECT id, username, timezone FROM users")
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var (
			user     User
			timezone sql.NullString // the column has a default but is nullable
		)
		if err := rows.Scan(&user.Id, &user.Username, &timezone); err != nil {
			fmt.Println(err)
			continue
		}
		user.Timezone = timezone.String
		users = append(users, user)
	}
	// Next returning false can also mean the iteration failed part-way.
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return users
}
// MessagesGet returns every row of the messages table.
//
// Timestamps are formatted by SQLite as "YYYY-MM-DD HH:MM:SS". Edited is true
// when the row's edited column equals 1.
//
// It returns nil when the query fails. Rows that cannot be read are skipped.
// All errors are printed to standard output.
func (db *Database) MessagesGet() []Message {
	rows, err := db.db.Query(`
SELECT
messages.id,
messages.username,
messages.content,
strftime('%Y-%m-%d %H:%M:%S', messages.created_at) as created_at,
messages.edited
FROM
messages
`)
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		var (
			msg       Message
			createdAt sql.NullString // strftime yields NULL for unparsable input
			edited    sql.NullInt64  // the column has a default but is nullable
		)
		if err := rows.Scan(&msg.Id, &msg.SenderUsername, &msg.Content, &createdAt, &edited); err != nil {
			fmt.Println(err)
			continue
		}
		msg.Timestamp = createdAt.String
		msg.Edited = edited.Valid && edited.Int64 == 1
		messages = append(messages, msg)
	}
	// Next returning false can also mean the iteration failed part-way.
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return messages
}
// UserNameExists reports whether a users row with the given username exists.
//
// It returns false when the lookup fails; the error is printed to standard
// output.
func (db *Database) UserNameExists(username string) bool {
	// Only existence matters, so no column data (in particular not the
	// password hash) is fetched.
	rows, err := db.db.Query("SELECT 1 FROM users WHERE username = ?", username)
	if err != nil {
		fmt.Println(err)
		return false
	}
	defer rows.Close()
	return rows.Next()
}
// UserExists reports whether a users row with the given username exists.
//
// It returns false when the lookup fails; the error is printed to standard
// output.
func (db *Database) UserExists(username string) bool {
	// Only existence matters, so no column data (in particular not the
	// password hash) is fetched.
	rows, err := db.db.Query("SELECT 1 FROM users WHERE username = ?", username)
	if err != nil {
		fmt.Println(err)
		return false
	}
	defer rows.Close()
	return rows.Next()
}
// UserNameChange renames the users row called oldUsername to newUsername.
//
// Errors from the update are printed to standard output and otherwise ignored.
//
// NOTE(review): only the users table is touched here; rows in messages that
// carry oldUsername keep it. Confirm that callers handle this.
func (db *Database) UserNameChange(oldUsername, newUsername string) {
	const query = "UPDATE users SET username = ? WHERE username = ?"
	if _, err := db.db.Exec(query, newUsername, oldUsername); err != nil {
		fmt.Println(err)
	}
}
// UserMessagesGet returns the messages whose username column equals username,
// newest first.
//
// Timestamps use the "YYYY-MM-DD HH:MM:SS" form. Edited is true when the
// row's edited column equals 1.
//
// It returns nil when the query fails. Rows that cannot be read are skipped.
// All errors are printed to standard output.
func (db *Database) UserMessagesGet(username string) []Message {
	// Go layout matching the strftime pattern used in the query below.
	const layout = "2006-01-02 15:04:05"

	// Columns are named explicitly so the Scan destinations always line up
	// with the result set.
	rows, err := db.db.Query(`
		SELECT
			messages.id,
			messages.username,
			messages.content,
			strftime('%Y-%m-%d %H:%M:%S', messages.created_at) AS created_at,
			messages.edited
		FROM messages
		WHERE messages.username = ?
		ORDER BY messages.created_at DESC
	`, username)
	if err != nil {
		fmt.Println(err)
		return nil
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		var (
			msg       Message
			createdAt sql.NullString // strftime yields NULL for unparsable input
			edited    sql.NullInt64  // the column has a default but is nullable
		)
		if err := rows.Scan(&msg.Id, &msg.SenderUsername, &msg.Content, &createdAt, &edited); err != nil {
			fmt.Println(err)
			continue
		}

		// Normalise the timestamp through the fixed layout; text that does
		// not parse is passed through unchanged rather than dropped.
		msg.Timestamp = createdAt.String
		if t, err := time.Parse(layout, createdAt.String); err == nil {
			msg.Timestamp = t.Format(layout)
		}
		msg.Edited = edited.Valid && edited.Int64 == 1
		messages = append(messages, msg)
	}
	// Next returning false can also mean the iteration failed part-way.
	if err := rows.Err(); err != nil {
		fmt.Println(err)
	}
	return messages
}
// MessageDeleteId deletes the messages row with the given id, whoever sent it.
//
// Errors from the delete are printed to standard output and otherwise ignored.
func (db *Database) MessageDeleteId(id string) {
	const query = "DELETE FROM messages WHERE id = ?"
	if _, err := db.db.Exec(query, id); err != nil {
		fmt.Println(err)
	}
}
// MessageDeleteIfOwner deletes the messages row with the given id, but only
// when its username column equals username.
//
// It returns the number of rows deleted, or 0 and the error when either the
// statement or the row count fails.
func (db *Database) MessageDeleteIfOwner(id string, username string) (int, error) {
	const query = "DELETE FROM messages WHERE id = ? AND username = ?"

	result, err := db.db.Exec(query, id, username)
	if err != nil {
		return 0, err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// MessageEditIfOwner replaces the content of the messages row with the given
// id and sets its edited flag to 1, but only when its username column equals
// username.
//
// It returns the number of rows updated, or 0 and the error when either the
// statement or the row count fails.
func (db *Database) MessageEditIfOwner(id string, content string, username string) (int, error) {
	const query = "UPDATE messages SET content = ?, edited = 1 WHERE id = ? AND username = ?"

	result, err := db.db.Exec(query, content, id, username)
	if err != nil {
		return 0, err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
// DeleteOldMessages deletes messages whose created_at lies more than
// ageMinutes minutes before the database's current time.
//
// A non-positive ageMinutes does nothing. Errors from the delete are printed
// to standard output and otherwise ignored.
func (db *Database) DeleteOldMessages(ageMinutes int) {
	if ageMinutes > 0 {
		// The statement appends " minutes" to this value to form a negative
		// datetime modifier.
		offset := "-" + strconv.Itoa(ageMinutes)
		if _, err := db.db.Exec("DELETE FROM messages WHERE created_at < datetime('now', ? || ' minutes')", offset); err != nil {
			fmt.Println(err)
		}
	}
}
// UserDelete deletes the users row with the given username.
//
// Errors from the delete are printed to standard output and otherwise ignored.
func (db *Database) UserDelete(username string) {
	const query = "DELETE FROM users WHERE username = ?"
	if _, err := db.db.Exec(query, username); err != nil {
		fmt.Println(err)
	}
}
// UsersDelete deletes every row of the users table.
//
// Errors from the delete are printed to standard output and otherwise ignored.
func (db *Database) UsersDelete() {
	if _, err := db.db.Exec("DELETE FROM users"); err != nil {
		fmt.Println(err)
	}
}
// MessagesDelete deletes every row of the messages table.
//
// Errors from the delete are printed to standard output and otherwise ignored.
func (db *Database) MessagesDelete() {
	if _, err := db.db.Exec("DELETE FROM messages"); err != nil {
		fmt.Println(err)
	}
}