188 lines
5.6 KiB
Go
188 lines
5.6 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
_ "github.com/lib/pq"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
type DatabaseConfig struct {
|
|
Type string
|
|
Host string
|
|
Port string
|
|
User string
|
|
Password string
|
|
Name string
|
|
Path string
|
|
SSLMode string
|
|
}
|
|
|
|
func InitializeDatabase(exPath, dataDirFlag string) (*sqlx.DB, error) {
|
|
config := getDatabaseConfig(exPath, dataDirFlag)
|
|
|
|
if config.Type == "postgres" {
|
|
return initializePostgres(config)
|
|
}
|
|
return initializeSQLite(config)
|
|
}
|
|
|
|
func getDatabaseConfig(exPath, dataDirFlag string) DatabaseConfig {
|
|
dbUser := os.Getenv("DB_USER")
|
|
dbPassword := os.Getenv("DB_PASSWORD")
|
|
dbName := os.Getenv("DB_NAME")
|
|
dbHost := os.Getenv("DB_HOST")
|
|
dbPort := os.Getenv("DB_PORT")
|
|
dbSSL := os.Getenv("DB_SSLMODE")
|
|
|
|
sslMode := dbSSL
|
|
if dbSSL == "true" {
|
|
sslMode = "require"
|
|
} else if dbSSL == "false" || dbSSL == "" {
|
|
sslMode = "disable"
|
|
}
|
|
|
|
if dbUser != "" && dbPassword != "" && dbName != "" && dbHost != "" && dbPort != "" {
|
|
return DatabaseConfig{
|
|
Type: "postgres",
|
|
Host: dbHost,
|
|
Port: dbPort,
|
|
User: dbUser,
|
|
Password: dbPassword,
|
|
Name: dbName,
|
|
SSLMode: sslMode,
|
|
}
|
|
}
|
|
|
|
// Use datadir flag if provided, otherwise fall back to executable directory
|
|
dataPath := exPath
|
|
if dataDirFlag != "" {
|
|
dataPath = dataDirFlag
|
|
}
|
|
|
|
return DatabaseConfig{
|
|
Type: "sqlite",
|
|
Path: filepath.Join(dataPath, "dbdata"),
|
|
}
|
|
}
|
|
|
|
func initializePostgres(config DatabaseConfig) (*sqlx.DB, error) {
|
|
dsn := fmt.Sprintf(
|
|
"user=%s password=%s dbname=%s host=%s port=%s sslmode=%s",
|
|
config.User, config.Password, config.Name, config.Host, config.Port, config.SSLMode,
|
|
)
|
|
|
|
db, err := sqlx.Open("postgres", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open postgres connection: %w", err)
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func initializeSQLite(config DatabaseConfig) (*sqlx.DB, error) {
|
|
if err := os.MkdirAll(config.Path, 0751); err != nil {
|
|
return nil, fmt.Errorf("could not create dbdata directory: %w", err)
|
|
}
|
|
|
|
dbPath := filepath.Join(config.Path, "users.db")
|
|
db, err := sqlx.Open("sqlite", dbPath+"?_pragma=foreign_keys(1)&_busy_timeout=3000")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open sqlite database: %w", err)
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, fmt.Errorf("failed to ping sqlite database: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
type HistoryMessage struct {
|
|
ID int `json:"id" db:"id"`
|
|
UserID string `json:"user_id" db:"user_id"`
|
|
ChatJID string `json:"chat_jid" db:"chat_jid"`
|
|
SenderJID string `json:"sender_jid" db:"sender_jid"`
|
|
MessageID string `json:"message_id" db:"message_id"`
|
|
Timestamp time.Time `json:"timestamp" db:"timestamp"`
|
|
MessageType string `json:"message_type" db:"message_type"`
|
|
TextContent string `json:"text_content" db:"text_content"`
|
|
MediaLink string `json:"media_link" db:"media_link"`
|
|
QuotedMessageID string `json:"quoted_message_id,omitempty" db:"quoted_message_id"`
|
|
DataJson string `json:"data_json" db:"datajson"`
|
|
}
|
|
|
|
func (s *server) saveMessageToHistory(userID, chatJID, senderJID, messageID, messageType, textContent, mediaLink, quotedMessageID, dataJson string) error {
|
|
query := `INSERT INTO message_history (user_id, chat_jid, sender_jid, message_id, timestamp, message_type, text_content, media_link, quoted_message_id, datajson)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`
|
|
if s.db.DriverName() == "sqlite" {
|
|
query = `INSERT INTO message_history (user_id, chat_jid, sender_jid, message_id, timestamp, message_type, text_content, media_link, quoted_message_id, datajson)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
|
|
}
|
|
_, err := s.db.Exec(query, userID, chatJID, senderJID, messageID, time.Now(), messageType, textContent, mediaLink, quotedMessageID, dataJson)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save message to history: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *server) trimMessageHistory(userID, chatJID string, limit int) error {
|
|
var queryHistory, querySecrets string
|
|
|
|
if s.db.DriverName() == "postgres" {
|
|
queryHistory = `
|
|
DELETE FROM message_history
|
|
WHERE id IN (
|
|
SELECT id FROM message_history
|
|
WHERE user_id = $1 AND chat_jid = $2
|
|
ORDER BY timestamp DESC
|
|
OFFSET $3
|
|
)`
|
|
|
|
querySecrets = `
|
|
DELETE FROM whatsmeow_message_secrets
|
|
WHERE message_id IN (
|
|
SELECT message_id FROM message_history
|
|
WHERE user_id = $1 AND chat_jid = $2
|
|
ORDER BY timestamp DESC
|
|
OFFSET $3
|
|
)`
|
|
} else { // sqlite
|
|
queryHistory = `
|
|
DELETE FROM message_history
|
|
WHERE id IN (
|
|
SELECT id FROM message_history
|
|
WHERE user_id = ? AND chat_jid = ?
|
|
ORDER BY timestamp DESC
|
|
LIMIT -1 OFFSET ?
|
|
)`
|
|
|
|
querySecrets = `
|
|
DELETE FROM whatsmeow_message_secrets
|
|
WHERE message_id IN (
|
|
SELECT message_id FROM message_history
|
|
WHERE user_id = ? AND chat_jid = ?
|
|
ORDER BY timestamp DESC
|
|
LIMIT -1 OFFSET ?
|
|
)`
|
|
}
|
|
|
|
if _, err := s.db.Exec(querySecrets, userID, chatJID, limit); err != nil {
|
|
return fmt.Errorf("failed to trim message secrets: %w", err)
|
|
}
|
|
|
|
if _, err := s.db.Exec(queryHistory, userID, chatJID, limit); err != nil {
|
|
return fmt.Errorf("failed to trim message history: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|