From a50770e6b96eb8786126c1ec9e8abf7cade95678 Mon Sep 17 00:00:00 2001 From: "Gustavo \"Guz\" L. de Mello" Date: Wed, 14 Aug 2024 14:21:10 -0300 Subject: [PATCH] feat(guilddb,sqlite): support for multiple guilds --- internals/guilddb/guilddb.go | 64 +++++---- internals/guilddb/sqlite.go | 251 +++++++++++++++++++++++------------ 2 files changed, 206 insertions(+), 109 deletions(-) diff --git a/internals/guilddb/guilddb.go b/internals/guilddb/guilddb.go index 38fb9c0..82539ca 100644 --- a/internals/guilddb/guilddb.go +++ b/internals/guilddb/guilddb.go @@ -5,36 +5,41 @@ import ( "errors" ) -type Message struct { - ID string - ChannelID string - Language lang.Language - OriginID *string - OriginChannelID *string +type Guild struct { + ID string } type Channel struct { ID string + GuildID string Language lang.Language } - type ChannelGroup []Channel +type Message struct { + ID string + ChannelID string + GuildID string + Language lang.Language + OriginID *string + OriginChannelID *string +} + type GuildDB interface { // Selects and returns a Message from the database, based on the // key pair of Channel's ID and Message's ID. // // Will return ErrNotFound if no message is found or ErrInternal. - Message(channelID, messageID string) (Message, error) + Message(guildID, channelID, ID string) (Message, error) // Returns a slice of Messages with the provided Message.OriginChannelID and Message.OriginID. // // Will return ErrNotFound if no message is found (slice's length == 0) or ErrInternal. - MessagesWithOrigin(originChannelID, originID string) ([]Message, error) + MessagesWithOrigin(guildID, originChannelID, originID string) ([]Message, error) // Returns a Messages with the provided Message.OriginChannelID, Message.OriginID // and Message.Language. // // Will return ErrNotFound if no message is found or ErrInternal. - MessageWithOriginByLang(originChannelId, originId string, language lang.Language) (Message, error) + MessageWithOriginByLang(guildID, originChannelId, originId string, language lang.Language) (Message, error) // Inserts a new Message object in the database. // // Message.ChannelID and Message.ID must be a unique pair and not already @@ -46,58 +51,71 @@ type GuildDB interface { // is a translated one. // // Will return ErrNoAffect if the object already exists or ErrInternal. - MessageInsert(message Message) error + MessageInsert(m Message) error // Updates the Message object in the database. Message.ID and Message.ChannelID // are used to find the correct message. // // Will return ErrNoAffect if no object was updated or ErrInternal. - MessageUpdate(message Message) error + MessageUpdate(m Message) error // Deletes the Message object in the database. Message.ID and Message.ChannelID // are used to find the correct message. // // Will return ErrNoAffect if no object was deleted or ErrInternal. - MessageDelete(message Message) error + MessageDelete(m Message) error // Selects and returns a Channel from the database, based on the // ID provided. // // Will return ErrNotFound if no channel is found or ErrInternal. - Channel(channelID string) (Channel, error) + Channel(guildID, ID string) (Channel, error) // Inserts a new Channel object in the database. // // Channel.ID must be unique and not already in the database. // // Will return ErrNoAffect if the object already exists or ErrInternal. - ChannelInsert(channel Channel) error + ChannelInsert(c Channel) error // Updates the Channel object in the database. Channel.ID is used to find the // correct Channel. // // Will return ErrNoAffect if no object was updated or ErrInternal. - ChannelUpdate(channel Channel) error + ChannelUpdate(c Channel) error // Deletes the Channel object in the database. Channel.ID is used to find the // correct Channel. // - // Will return ErrNoAffect if no object was updated or ErrInternal. - ChannelDelete(channel Channel) error + // Will return ErrNoAffect if no object was deleted or ErrInternal. + ChannelDelete(c Channel) error // Selects and returns a ChannelGroup from the database. Finds a ChannelGroup // that has a Channel if the provided ID. // // Channels cannot be in two ChannelGroup at the same time. // // Will return ErrNotFound if no channel is found or ErrInternal. - ChannelGroup(channelID string) (ChannelGroup, error) + ChannelGroup(guildID, ID string) (ChannelGroup, error) // Inserts a new ChannelGroup object in the database. ChannelGroup must be unique // and not have Channels that are already in other groups. // // Will return ErrNoAffect if the object already exists or ErrInternal. - ChannelGroupInsert(group ChannelGroup) error + ChannelGroupInsert(g ChannelGroup) error // Updates the ChannelGroup object in the database. // // Will return ErrNoAffect if no object was updated or ErrInternal. - ChannelGroupUpdate(channel ChannelGroup) error + ChannelGroupUpdate(g ChannelGroup) error // Deletes the ChannelGroup object in the database. // - // Will return ErrNoAffect if no object was updated or ErrInternal. - ChannelGroupDelete(channel ChannelGroup) error + // Will return ErrNoAffect if no object was deleted or ErrInternal. + ChannelGroupDelete(g ChannelGroup) error + // Selects and returns a Guild from the database. + // + // Will return ErrNotFound if no Guild is found or ErrInternal. + Guild(ID string) (Guild, error) + // Inserts a new Guild object in the database. Guild.ID must be unique and + // not already in the database. + // + // Will return ErrNoAffect if the object already exists or ErrInternal. + GuildInsert(g Guild) error + // Delete a Guild from the database. Guild.ID is used to find the object. + // + // Will return ErrNoAffect if no object was deleted or ErrInternal. + GuildDelete(g Guild) error } var ErrNoAffect = errors.New("Not able to affect anything in the database") diff --git a/internals/guilddb/sqlite.go b/internals/guilddb/sqlite.go index a24e607..856c8a5 100644 --- a/internals/guilddb/sqlite.go +++ b/internals/guilddb/sqlite.go @@ -28,66 +28,77 @@ func (db *SQLiteDB) Close() error { } func (db *SQLiteDB) Prepare() error { - _, err := db.sql.Exec(` - CREATE TABLE IF NOT EXISTS messages ( - ID text NOT NULL, - ChannelID text NOT NULL, - Language text NOT NULL, - OriginID text, - OriginChannelID text, - PRIMARY KEY(ID, ChannelID), - FOREIGN KEY(ChannelID) REFERENCES channels(ID), - FOREIGN KEY(OriginID) REFERENCES messages(ID), - FOREIGN KEY(OriginChannelID) REFERENCES channels(ID) + if _, err := db.sql.Exec(` + CREATE TABLE IF NOT EXISTS guilds ( + ID text NOT NULL, + PRIMARY KEY(ID) ); - `) - if err != nil { + `); err != nil { return errors.Join(ErrInternal, err) } - _, err = db.sql.Exec(` + if _, err := db.sql.Exec(` CREATE TABLE IF NOT EXISTS channels ( + GuildID text NOT NULL, ID text NOT NULL, Language text NOT NULL, - PRIMARY KEY(ID) + PRIMARY KEY(ID, GuildID), + FOREIGN KEY(GuildID) REFERENCES guilds(ID) ); CREATE TABLE IF NOT EXISTS channel-groups ( - Channels text NOT NULL PRIMARY KEY + GuildID text NOT NULL, + Channels text NOT NULL, + PRIMARY KEY(Channels, GuildID), + FOREIGN KEY(GuildID) REFERENCES guilds(ID) ); - `) - if err != nil { + `); err != nil { + return errors.Join(ErrInternal, err) + } + + if _, err := db.sql.Exec(` + CREATE TABLE IF NOT EXISTS messages ( + GuildID text NOT NULL, + ChannelID text NOT NULL, + ID text NOT NULL, + Language text NOT NULL, + OriginChannelID text, + OriginID text, + PRIMARY KEY(ID, ChannelID, GuildID), + FOREIGN KEY(ChannelID) REFERENCES channels(ID), + FOREIGN KEY(GuildID) REFERENCES guilds(ID), + FOREIGN KEY(OriginID) REFERENCES messages(ID), + FOREIGN KEY(OriginChannelID) REFERENCES channels(ID) + ); + `); err != nil { return errors.Join(ErrInternal, err) } return nil } -func (db *SQLiteDB) Message(channelID, messageID string) (Message, error) { +func (db *SQLiteDB) Message(guildID, channelID, messageID string) (Message, error) { return db.selectMessage(` - WHERE "ID" = $1 AND "ChannelID" = $2 - `, messageID, channelID) - SELECT * FROM messages + WHERE "GuildID" = $1 AND "ChannelID" = $2 AND "ID" = $3 + `, guildID, channelID, messageID) } -func (db *SQLiteDB) MessagesWithOrigin(originID, originChannelID string) ([]Message, error) { +func (db *SQLiteDB) MessagesWithOrigin(guildID, originChannelID, originID string) ([]Message, error) { return db.selectMessages(` - WHERE "OriginID" = $1 AND "OriginChannelID" = $2 - `, originID, originChannelID) - SELECT * FROM messages + WHERE "GuildID" = $1 AND "OriginChannelID" = $2 AND "OriginID" = $3 + `, guildID, originChannelID, originID) } func (db *SQLiteDB) MessageWithOriginByLang( - originChannelID, originID string, + guildID, originChannelID, originID string, language lang.Language, ) (Message, error) { return db.selectMessage(` - SELECT * FROM messages - WHERE "OriginID" = $1 AND "OriginChannelID" = $2 AND "Language" = $3 - `, originID, originChannelID, language) + WHERE "GuildID" = $1 AND "OriginChannelID" = $2 AND "OriginID" = $3 AND "Language" = $4 + `, guildID, originChannelID, originID, language) } func (db *SQLiteDB) MessageInsert(m Message) error { - _, err := db.Channel(m.ChannelID) + _, err := db.Channel(m.GuildID, m.ChannelID) if errors.Is(err, ErrNotFound) { return errors.Join( ErrPreconditionFailed, @@ -102,9 +113,9 @@ func (db *SQLiteDB) MessageInsert(m Message) error { } r, err := db.sql.Exec(` - INSERT INTO messages (ID, ChannelID, Language, OriginID, OriginChannelID) - VALUES ($1, $2, $3, $4, $5) - `, m.ID, m.ChannelID, m.Language, m.OriginID, m.OriginChannelID) + INSERT INTO messages (GuildID, ChannelID, ID, Language, OriginChannelID, OriginID) + VALUES ($1, $2, $3, $4, $5, $6) + `, m.GuildID, m.ChannelID, m.ID, m.Language, m.OriginChannelID, m.OriginID) if err != nil { return errors.Join(ErrInternal, err) @@ -115,16 +126,17 @@ func (db *SQLiteDB) MessageInsert(m Message) error { return nil } -func (db *SQLiteDB) MessageUpdate(message Message) error { +func (db *SQLiteDB) MessageUpdate(m Message) error { r, err := db.sql.Exec(` UPDATE messages SET Language = $1, OriginChannelID = $2, OriginID = $3 - WHERE "ID" = $4 AND "ChannelID" = $5 - `, message.Language, - message.OriginChannelID, - message.OriginID, - message.ID, - message.ChannelID, + WHERE "GuildID" = $4 AND "ChannelID" = $5 AND "ID" = $6 + `, m.Language, + m.OriginChannelID, + m.OriginID, + m.GuildID, + m.ChannelID, + m.ID, ) if err != nil { @@ -136,11 +148,11 @@ func (db *SQLiteDB) MessageUpdate(message Message) error { return nil } -func (db *SQLiteDB) MessageDelete(message Message) error { +func (db *SQLiteDB) MessageDelete(m Message) error { _, err := db.sql.Exec(` DELETE channels - WHERE "OriginID" = $1 AND "OriginChannelID" = $2 - `, message.ID, message.ChannelID) + WHERE "GuildID" = $1 AND "OriginChannelID" = $2 AND "OriginID" = $3 + `, m.GuildID, m.ChannelID, m.ID) if err != nil { return errors.Join(ErrInternal, err) @@ -148,8 +160,8 @@ func (db *SQLiteDB) MessageDelete(message Message) error { r, err := db.sql.Exec(` DELETE channels - WHERE "ID" = $1 AND "ChannelID" = $2 - `, message.ID, message.ChannelID) + WHERE "GuildID" = $1 AND "ChannelID" = $2 AND "ID" = $3 + `, m.GuildID, m.ChannelID, m.ID) if err != nil { return errors.Join(ErrInternal, err) @@ -162,8 +174,11 @@ func (db *SQLiteDB) MessageDelete(message Message) error { func (db *SQLiteDB) selectMessage(query string, args ...any) (Message, error) { var m Message - err := db.sql.QueryRow(query, args...). - Scan(&m.ID, &m.ChannelID, &m.Language, &m.OriginID, &m.OriginChannelID) + err := db.sql.QueryRow(fmt.Sprintf(` + SELECT (GuildID, ChannelID, ID, Language, OriginChannelID, OriginID) FROM messages + %s + `, query), args...). + Scan(&m.GuildID, &m.ChannelID, &m.ID, &m.Language, &m.OriginChannelID, &m.OriginID) if errors.Is(err, sql.ErrNoRows) { return m, errors.Join(ErrNotFound, err) @@ -175,7 +190,10 @@ func (db *SQLiteDB) selectMessage(query string, args ...any) (Message, error) { } func (db *SQLiteDB) selectMessages(query string, args ...any) ([]Message, error) { - r, err := db.sql.Query(query, args...) + r, err := db.sql.Query(fmt.Sprintf(` + SELECT (GuildID, ChannelID, ID, Language, OriginChannelID, OriginID) FROM messages + %s + `, query), args...) if err != nil { return []Message{}, errors.Join(ErrInternal, err) @@ -185,7 +203,7 @@ func (db *SQLiteDB) selectMessages(query string, args ...any) ([]Message, error) for r.Next() { var m Message - err = r.Scan(&m.ID, &m.ChannelID, &m.Language, &m.OriginID, &m.OriginChannelID) + err = r.Scan(&m.GuildID, &m.ChannelID, &m.ID, &m.Language, &m.OriginChannelID, &m.OriginID) if err != nil { return ms, errors.Join( ErrInternal, @@ -206,18 +224,17 @@ func (db *SQLiteDB) selectMessages(query string, args ...any) ([]Message, error) return ms, err } -func (db *SQLiteDB) Channel(channelID string) (Channel, error) { +func (db *SQLiteDB) Channel(guildID, ID string) (Channel, error) { return db.selectChannel(` - SELECT (ID, Language) FROM channels - WHERE "ID" = $1 - `, channelID) + WHERE "GuildID" = $1 AND "ID" = $2 + `, guildID, ID) } func (db *SQLiteDB) ChannelInsert(c Channel) error { r, err := db.sql.Exec(` - INSERT INTO channels (ID, Language) - VALUES ($1, $2) - `, c.ID, c.Language) + INSERT INTO channels (GuildID, ID, Language) + VALUES ($1, $2, $3) + `, c.GuildID, c.ID, c.Language) if err != nil { return errors.Join(ErrInternal, err) @@ -228,12 +245,12 @@ func (db *SQLiteDB) ChannelInsert(c Channel) error { return nil } -func (db *SQLiteDB) ChannelUpdate(channel Channel) error { +func (db *SQLiteDB) ChannelUpdate(c Channel) error { r, err := db.sql.Exec(` UPDATE channels SET Language = $1 - WHERE "ID" = $2 - `, channel.Language, channel.ID) + WHERE "GuildID" = $2 AND "ID" = $3 + `, c.Language, c.GuildID, c.ID) if err != nil { return errors.Join(ErrInternal, err) @@ -244,11 +261,11 @@ func (db *SQLiteDB) ChannelUpdate(channel Channel) error { return nil } -func (db *SQLiteDB) ChannelDelete(channel Channel) error { +func (db *SQLiteDB) ChannelDelete(c Channel) error { r, err := db.sql.Exec(` DELETE channels - WHERE "ID" = $1 - `, channel.ID) + WHERE "GuildID" = $1 AND "ID" = $2 + `, c.ID, c.ID) if err != nil { return errors.Join(ErrInternal, err) @@ -259,13 +276,13 @@ func (db *SQLiteDB) ChannelDelete(channel Channel) error { return nil } -func (db *SQLiteDB) ChannelGroup(channelID string) (ChannelGroup, error) { +func (db *SQLiteDB) ChannelGroup(guildID, channelID string) (ChannelGroup, error) { var g string err := db.sql.QueryRow(` - SELECT (ID, Language) FROM channel-groups - WHERE "Channels" LIKE "%$1%" - `, channelID).Scan(&g) + SELECT (GuildID, ID, Language) FROM channel-groups + WHERE "GuildID" = $1 AND "Channels" LIKE "%$2%" + `, guildID, channelID).Scan(&g) if errors.Is(err, sql.ErrNoRows) { return ChannelGroup{}, errors.Join(ErrNotFound, err) @@ -285,9 +302,8 @@ func (db *SQLiteDB) ChannelGroup(channelID string) (ChannelGroup, error) { } cs, err := db.selectChannels(fmt.Sprintf(` - SELECT (ID, Language) FROM channels - WHERE %s - `, strings.Join(ids, " OR "))) + WHERE %s AND "GuildID" = $1 + `, strings.Join(ids, " OR ")), guildID) if errors.Is(err, ErrNotFound) || len(cs) != len(ids) { return ChannelGroup{}, errors.Join( @@ -302,17 +318,21 @@ func (db *SQLiteDB) ChannelGroup(channelID string) (ChannelGroup, error) { return cs, nil } -func (db *SQLiteDB) ChannelGroupInsert(group ChannelGroup) error { +func (db *SQLiteDB) ChannelGroupInsert(g ChannelGroup) error { + if len(g) != 0 { + return nil + } + var ids []string - for _, c := range group { + for _, c := range g { ids = append(ids, c.ID) } slices.Sort(ids) r, err := db.sql.Exec(` - INSERT INTO channel-groups (Channels) - VALUES ($1) - `, strings.Join(ids, ",")) + INSERT INTO channel-groups (GuildID, Channels) + VALUES ($1, $2) + `, g[0].GuildID, strings.Join(ids, ",")) if err != nil { return errors.Join(ErrInternal, err) @@ -323,9 +343,13 @@ func (db *SQLiteDB) ChannelGroupInsert(group ChannelGroup) error { return nil } -func (db *SQLiteDB) ChannelGroupUpdate(group ChannelGroup) error { +func (db *SQLiteDB) ChannelGroupUpdate(g ChannelGroup) error { + if len(g) != 0 { + return nil + } + var ids, idsq []string - for _, c := range group { + for _, c := range g { ids = append(ids, c.ID) idsq = append(idsq, "\"ID\" LIKE \""+c.ID+"\"") } @@ -335,9 +359,10 @@ func (db *SQLiteDB) ChannelGroupUpdate(group ChannelGroup) error { fmt.Sprintf(` UPDATE channel-groups SET Channels = $1 - WHERE %s + WHERE %s AND "GuildID" = $2 `, strings.Join(idsq, " OR ")), strings.Join(ids, ","), + g[0].GuildID, ) if err != nil { @@ -349,9 +374,13 @@ func (db *SQLiteDB) ChannelGroupUpdate(group ChannelGroup) error { return nil } -func (db *SQLiteDB) ChannelGroupDelete(group ChannelGroup) error { +func (db *SQLiteDB) ChannelGroupDelete(g ChannelGroup) error { + if len(g) != 0 { + return nil + } + var ids, idsq []string - for _, c := range group { + for _, c := range g { ids = append(ids, c.ID) idsq = append(idsq, "\"ID\" LIKE \""+c.ID+"\"") } @@ -360,9 +389,9 @@ func (db *SQLiteDB) ChannelGroupDelete(group ChannelGroup) error { r, err := db.sql.Exec( fmt.Sprintf(` DELETE FROM channel-groups - WHERE %s + WHERE %s AND "GuildID" = $1 `, strings.Join(idsq, " OR ")), - strings.Join(ids, ","), + g[0].GuildID, ) if err != nil { @@ -376,8 +405,10 @@ func (db *SQLiteDB) ChannelGroupDelete(group ChannelGroup) error { func (db *SQLiteDB) selectChannel(query string, args ...any) (Channel, error) { var c Channel - err := db.sql.QueryRow(query, args...). - Scan(&c.ID, &c.Language) + err := db.sql.QueryRow(fmt.Sprintf(` + SELECT (GuildID, ID, Language) FROM channels + %s + `, query), args...).Scan(&c.GuildID, &c.ID, &c.Language) if errors.Is(err, sql.ErrNoRows) { return c, errors.Join(ErrNotFound, err) @@ -389,7 +420,10 @@ func (db *SQLiteDB) selectChannel(query string, args ...any) (Channel, error) { } func (db *SQLiteDB) selectChannels(query string, args ...any) ([]Channel, error) { - r, err := db.sql.Query(query, args...) + r, err := db.sql.Query(fmt.Sprintf(` + SELECT (GuildID, ID, Language) FROM channels + %s + `, query), args...) if err != nil { return []Channel{}, errors.Join(ErrInternal, err) @@ -399,7 +433,7 @@ func (db *SQLiteDB) selectChannels(query string, args ...any) ([]Channel, error) for r.Next() { var c Channel - err = r.Scan(&c.ID, &c.Language) + err = r.Scan(&c.GuildID, &c.ID, &c.Language) if err != nil { return cs, errors.Join( ErrInternal, @@ -419,3 +453,48 @@ func (db *SQLiteDB) selectChannels(query string, args ...any) ([]Channel, error) } return cs, err } + +func (db *SQLiteDB) Guild(ID string) (Guild, error) { + var g Guild + + if err := db.sql.QueryRow(` + SELECT (ID) FROM guilds + WHERE "ID" = $1 + `, ID).Scan(g.ID); err != nil && errors.Is(err, sql.ErrNoRows) { + return Guild{}, errors.Join(ErrNotFound, err) + } else if err != nil { + return Guild{}, errors.Join(ErrInternal, err) + } + + return g, nil +} + +func (db *SQLiteDB) GuildInsert(g Guild) error { + r, err := db.sql.Exec(` + INSERT INTO guilds (ID) + VALUES ($1) + `, g.ID) + + if err != nil { + return errors.Join(ErrInternal, err) + } else if rows, _ := r.RowsAffected(); rows == 0 { + return ErrNoAffect + } + + return nil +} + +func (db *SQLiteDB) GuildDelete(g Guild) error { + r, err := db.sql.Exec(` + DELETE FROM guilds + WHERE "ID" = $1 + `, g.ID) + + if err != nil { + return errors.Join(ErrInternal, err) + } else if rows, _ := r.RowsAffected(); rows == 0 { + return ErrNoAffect + } + + return nil +}