aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.rs3
-rw-r--r--src/plugins/factoids/database.rs3
-rw-r--r--src/plugins/remind/database.rs128
-rw-r--r--src/plugins/remind/mod.rs80
-rw-r--r--src/plugins/remind/parser.rs41
-rw-r--r--src/plugins/tell/database.rs3
6 files changed, 203 insertions, 55 deletions
diff --git a/src/main.rs b/src/main.rs
index 3f56c50..61cc839 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -126,8 +126,7 @@ fn run() -> Result<(), Error> {
let pool = Arc::new(pool);
bot.add_plugin(Factoids::new(pool.clone()));
bot.add_plugin(Tell::new(pool.clone()));
- // TODO Use mysql pool
- bot.add_plugin(Remind::new(HashMap::new()));
+ bot.add_plugin(Remind::new(pool.clone()));
info!("Connected to MySQL server")
}
Err(e) => {
diff --git a/src/plugins/factoids/database.rs b/src/plugins/factoids/database.rs
index 702834f..6cd979e 100644
--- a/src/plugins/factoids/database.rs
+++ b/src/plugins/factoids/database.rs
@@ -1,6 +1,3 @@
-#[cfg(feature = "mysql")]
-extern crate dotenv;
-
use std::collections::HashMap;
#[cfg(feature = "mysql")]
use std::sync::Arc;
diff --git a/src/plugins/remind/database.rs b/src/plugins/remind/database.rs
index c0c127e..e434ec0 100644
--- a/src/plugins/remind/database.rs
+++ b/src/plugins/remind/database.rs
@@ -1,14 +1,30 @@
-#[cfg(feature = "mysql")]
-extern crate dotenv;
-
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt;
+#[cfg(feature = "mysql")]
+use std::sync::Arc;
+
+#[cfg(feature = "mysql")]
+use diesel::mysql::MysqlConnection;
+#[cfg(feature = "mysql")]
+use diesel::prelude::*;
+#[cfg(feature = "mysql")]
+use r2d2::Pool;
+#[cfg(feature = "mysql")]
+use r2d2_diesel::ConnectionManager;
+
+#[cfg(feature = "mysql")]
+use failure::ResultExt;
+
use chrono::NaiveDateTime;
use super::error::*;
+#[cfg(feature = "mysql")]
+static LAST_ID_SQL: &'static str = "SELECT LAST_INSERT_ID()";
+
+#[cfg_attr(feature = "mysql", derive(Queryable))]
#[derive(Clone, Debug)]
pub struct Event {
pub id: i64,
@@ -16,7 +32,7 @@ pub struct Event {
pub content: String,
pub author: String,
pub time: NaiveDateTime,
- pub repeat: Option<u64>,
+ pub repeat: Option<i64>,
}
impl fmt::Display for Event {
@@ -29,18 +45,20 @@ impl fmt::Display for Event {
}
}
+#[cfg_attr(feature = "mysql", derive(Insertable))]
+#[cfg_attr(feature = "mysql", table_name = "events")]
#[derive(Debug)]
pub struct NewEvent<'a> {
pub receiver: &'a str,
pub content: &'a str,
pub author: &'a str,
pub time: &'a NaiveDateTime,
- pub repeat: Option<u64>,
+ pub repeat: Option<i64>,
}
pub trait Database: Send + Sync {
- fn insert_event(&mut self, event: &NewEvent) -> Result<(), RemindError>;
- fn update_event_time(&mut self, id: i64, &NaiveDateTime) -> Result<(), RemindError>;
+ fn insert_event(&mut self, event: &NewEvent) -> Result<i64, RemindError>;
+ fn update_event_time(&mut self, id: i64, time: &NaiveDateTime) -> Result<(), RemindError>;
fn get_events_before(&self, time: &NaiveDateTime) -> Result<Vec<Event>, RemindError>;
fn get_user_events(&self, user: &str) -> Result<Vec<Event>, RemindError>;
fn get_event(&self, id: i64) -> Result<Event, RemindError>;
@@ -49,7 +67,7 @@ pub trait Database: Send + Sync {
// HashMap
impl Database for HashMap<i64, Event> {
- fn insert_event(&mut self, event: &NewEvent) -> Result<(), RemindError> {
+ fn insert_event(&mut self, event: &NewEvent) -> Result<i64, RemindError> {
let mut id = 0;
while self.contains_key(&id) {
id += 1;
@@ -65,7 +83,7 @@ impl Database for HashMap<i64, Event> {
};
match self.insert(id, event) {
- None => Ok(()),
+ None => Ok(id),
Some(_) => Err(ErrorKind::Duplicate)?,
}
}
@@ -126,3 +144,95 @@ impl Database for HashMap<i64, Event> {
}
}
}
+
+#[cfg(feature = "mysql")]
+mod schema {
+ table! {
+ events (id) {
+ id -> Bigint,
+ receiver -> Varchar,
+ content -> Text,
+ author -> Varchar,
+ time -> Timestamp,
+ repeat -> Nullable<Bigint>,
+ }
+ }
+}
+
+#[cfg(feature = "mysql")]
+use self::schema::events;
+
+#[cfg(feature = "mysql")]
+impl Database for Arc<Pool<ConnectionManager<MysqlConnection>>> {
+ fn insert_event(&mut self, event: &NewEvent) -> Result<i64, RemindError> {
+ use diesel::{self, dsl::sql, types::Bigint};
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+
+ diesel::insert_into(events::table)
+ .values(event)
+ .execute(conn)
+ .context(ErrorKind::MysqlError)?;
+
+ let id = sql::<Bigint>(LAST_ID_SQL)
+ .get_result(conn)
+ .context(ErrorKind::MysqlError)?;
+
+ Ok(id)
+ }
+
+ fn update_event_time(&mut self, id: i64, time: &NaiveDateTime) -> Result<(), RemindError> {
+ use self::events::columns;
+ use diesel;
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+
+ match diesel::update(events::table.filter(columns::id.eq(id)))
+ .set(columns::time.eq(time))
+ .execute(conn)
+ {
+ Ok(0) => Err(ErrorKind::NotFound)?,
+ Ok(_) => Ok(()),
+ Err(e) => Err(e).context(ErrorKind::MysqlError)?,
+ }
+ }
+
+ fn get_events_before(&self, time: &NaiveDateTime) -> Result<Vec<Event>, RemindError> {
+ use self::events::columns;
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+
+ Ok(events::table
+ .filter(columns::time.lt(time))
+ .load::<Event>(conn)
+ .context(ErrorKind::MysqlError)?)
+ }
+
+ fn get_user_events(&self, user: &str) -> Result<Vec<Event>, RemindError> {
+ use self::events::columns;
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+
+ Ok(events::table
+ .filter(columns::receiver.eq(user))
+ .load::<Event>(conn)
+ .context(ErrorKind::MysqlError)?)
+ }
+
+ fn get_event(&self, id: i64) -> Result<Event, RemindError> {
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+
+ Ok(events::table
+ .find(id)
+ .first(conn)
+ .context(ErrorKind::MysqlError)?)
+ }
+
+ fn delete_event(&mut self, id: i64) -> Result<(), RemindError> {
+ use self::events::columns;
+ use diesel;
+
+ let conn = &*self.get().context(ErrorKind::NoConnection)?;
+ match diesel::delete(events::table.filter(columns::id.eq(id))).execute(conn) {
+ Ok(0) => Err(ErrorKind::NotFound)?,
+ Ok(_) => Ok(()),
+ Err(e) => Err(e).context(ErrorKind::MysqlError)?,
+ }
+ }
+}
diff --git a/src/plugins/remind/mod.rs b/src/plugins/remind/mod.rs
index 8f9b628..71f9d03 100644
--- a/src/plugins/remind/mod.rs
+++ b/src/plugins/remind/mod.rs
@@ -58,12 +58,12 @@ fn run<T: Database>(client: &IrcClient, db: Arc<RwLock<T>>) {
debug!("Sent reminder {:?}", event);
if let Some(repeat) = event.repeat {
- let next_time = event.time + chrono::Duration::seconds(repeat as i64);
+ let next_time = event.time + chrono::Duration::seconds(repeat);
if let Err(e) = db.write().update_event_time(event.id, &next_time) {
error!("Failed to update reminder: {}", e);
} else {
- debug!("Updated time on: {:?}", event);
+ debug!("Updated time");
}
} else if let Err(e) = db.write().delete_event(event.id) {
error!("Failed to delete reminder: {}", e);
@@ -103,33 +103,50 @@ impl<T: 'static + Database> Remind<T> {
}
}
- fn set(&self, command: PluginCommand) -> Result<&str, RemindError> {
- let parser = CommandParser::try_from_tokens(command.tokens)?;
+ fn user_cmd(&self, command: PluginCommand) -> Result<String, RemindError> {
+ let parser = CommandParser::parse_target(command.tokens)?;
+
+ self.set(parser, &command.source)
+ }
+
+ fn me_cmd(&self, command: PluginCommand) -> Result<String, RemindError> {
+ let source = command.source.clone();
+ let parser = CommandParser::with_target(command.tokens, command.source)?;
+
+ self.set(parser, &source)
+ }
+
+ fn set(&self, parser: CommandParser, author: &str) -> Result<String, RemindError> {
debug!("parser: {:?}", parser);
- let mut target = parser.get_target();
- if target == "me" {
- target = &command.source;
- }
+ let target = parser.get_target();
+ let time = parser.get_time(Duration::from_secs(120))?;
let event = database::NewEvent {
receiver: target,
content: &parser.get_message(),
- author: &command.source,
- time: &parser.get_time(Duration::from_secs(120))?,
+ author: author,
+ time: &time,
repeat: parser
- .get_repeat(Duration::from_secs(300))?
- .map(|d| d.as_secs()),
+ .get_repeat(Duration::from_secs(600))?
+ .map(|d| d.as_secs() as i64),
};
debug!("New event: {:?}", event);
- Ok(self.events.write().insert_event(&event).map(|()| "Got it")?)
+ Ok(self.events
+ .write()
+ .insert_event(&event)
+ .map(|id| format!("Created reminder with id {} at {} UTC", id, time))?)
}
fn list(&self, user: &str) -> Result<String, RemindError> {
let mut events = self.events.read().get_user_events(user)?;
+ if events.is_empty() {
+ Err(ErrorKind::NotFound)?;
+ }
+
let mut list = events.remove(0).to_string();
for ev in events {
list.push_str("\r\n");
@@ -145,7 +162,10 @@ impl<T: 'static + Database> Remind<T> {
.remove(0)
.parse::<i64>()
.context(ErrorKind::Parsing)?;
- let event = self.events.read().get_event(id)?;
+ let event = self.events
+ .read()
+ .get_event(id)
+ .context(ErrorKind::NotFound)?;
if event.receiver.eq_ignore_ascii_case(&command.source)
|| event.author.eq_ignore_ascii_case(&command.source)
@@ -169,15 +189,18 @@ impl<T: 'static + Database> Remind<T> {
}
impl<T: Database> Plugin for Remind<T> {
- fn execute(&self, client: &IrcClient, _: &Message) -> ExecutionStatus {
- let mut has_reminder = self.has_reminder.write();
- if !*has_reminder {
- let events = Arc::clone(&self.events);
- let client = client.clone();
+ fn execute(&self, client: &IrcClient, msg: &Message) -> ExecutionStatus {
+ if let Command::JOIN(_, _, _) = msg.command {
+ let mut has_reminder = self.has_reminder.write();
+
+ if !*has_reminder {
+ let events = Arc::clone(&self.events);
+ let client = client.clone();
- spawn(move || run(&client, events));
+ spawn(move || run(&client, events));
- *has_reminder = true;
+ *has_reminder = true;
+ }
}
ExecutionStatus::Done
@@ -198,7 +221,8 @@ impl<T: Database> Plugin for Remind<T> {
let sub_command = command.tokens.remove(0);
let response = match sub_command.as_ref() {
- "user" => self.set(command).map(|s| s.to_owned()),
+ "user" => self.user_cmd(command),
+ "me" => self.me_cmd(command),
"delete" => self.delete(command).map(|s| s.to_owned()),
"list" => self.list(&source),
"help" => Ok(self.help().to_owned()),
@@ -241,7 +265,7 @@ pub mod error {
#[error = "RemindError"]
pub enum ErrorKind {
/// Invalid command error
- #[fail(display = "Incorrect Command. Send \"currency help\" for help.")]
+ #[fail(display = "Incorrect Command. Send \"remind help\" for help.")]
InvalidCommand,
/// Missing message error
@@ -287,5 +311,15 @@ pub mod error {
/// Not found error
#[fail(display = "No events found")]
NotFound,
+
+ /// MySQL error
+ #[cfg(feature = "mysql")]
+ #[fail(display = "Failed to execute MySQL Query")]
+ MysqlError,
+
+ /// No connection error
+ #[cfg(feature = "mysql")]
+ #[fail(display = "No connection to the database")]
+ NoConnection,
}
}
diff --git a/src/plugins/remind/parser.rs b/src/plugins/remind/parser.rs
index 2dbb040..e027aba 100644
--- a/src/plugins/remind/parser.rs
+++ b/src/plugins/remind/parser.rs
@@ -27,20 +27,30 @@ enum ParseState {
}
impl CommandParser {
- pub fn try_from_tokens(tokens: Vec<String>) -> Result<Self, RemindError> {
- if tokens.is_empty() {
- return Err(ErrorKind::MissingReceiver.into());
+ pub fn parse_target(mut tokens: Vec<String>) -> Result<Self, RemindError> {
+ let mut parser = CommandParser::default();
+
+ if let Some(target) = tokens.pop() {
+ parser.target = target;
+ } else {
+ Err(ErrorKind::MissingReceiver)?;
}
+ parser.parse_tokens(tokens)
+ }
+
+ pub fn with_target(tokens: Vec<String>, target: String) -> Result<Self, RemindError> {
let mut parser = CommandParser::default();
- let mut state = ParseState::None;
+ parser.target = target;
- let mut iter = tokens.into_iter();
- parser.target = iter.next()
- .expect("This should be guaranteed by the length check");
+ parser.parse_tokens(tokens)
+ }
+ fn parse_tokens(mut self, tokens: Vec<String>) -> Result<Self, RemindError> {
+ let mut state = ParseState::None;
let mut cur_str = String::new();
- while let Some(token) = iter.next() {
+
+ for token in tokens {
let next_state = match token.as_ref() {
"on" => ParseState::On,
"at" => ParseState::At,
@@ -58,30 +68,31 @@ impl CommandParser {
if next_state != state {
if state != ParseState::None {
- parser = parser.add_string_by_state(&state, cur_str)?;
+ self = self.add_string_by_state(&state, cur_str)?;
cur_str = String::new();
}
state = next_state;
}
}
- parser = parser.add_string_by_state(&state, cur_str)?;
- if parser.message.is_none() {
+ self = self.add_string_by_state(&state, cur_str)?;
+
+ if self.message.is_none() {
return Err(ErrorKind::MissingMessage.into());
}
- if parser.in_duration.is_some() && parser.at_time.is_some()
- || parser.in_duration.is_some() && parser.on_date.is_some()
+ if self.in_duration.is_some() && self.at_time.is_some()
+ || self.in_duration.is_some() && self.on_date.is_some()
{
return Err(ErrorKind::AmbiguousTime.into());
}
- if parser.in_duration.is_none() && parser.at_time.is_none() && parser.on_date.is_none() {
+ if self.in_duration.is_none() && self.at_time.is_none() && self.on_date.is_none() {
return Err(ErrorKind::MissingTime.into());
}
- Ok(parser)
+ Ok(self)
}
fn add_string_by_state(self, state: &ParseState, string: String) -> Result<Self, RemindError> {
diff --git a/src/plugins/tell/database.rs b/src/plugins/tell/database.rs
index 522df5a..75789e4 100644
--- a/src/plugins/tell/database.rs
+++ b/src/plugins/tell/database.rs
@@ -1,6 +1,3 @@
-#[cfg(feature = "mysql")]
-extern crate dotenv;
-
use std::collections::HashMap;
#[cfg(feature = "mysql")]
use std::sync::Arc;