diff --git a/clients/discord.py b/clients/discord.py index 50db0f1..68a603d 100644 --- a/clients/discord.py +++ b/clients/discord.py @@ -6,6 +6,28 @@ import discord from commands import CommandError +def cooldown(retries, timeout, failure): + def do_cooldown(function): + def wrapper(self, user, *args, **kwargs): + cooldowns = getattr(function, 'cooldowns', {}) + if user not in cooldowns: + cooldowns[user] = dict(tries=0, last_run=0) + cd = cooldowns[user] + if cd['tries'] < retries: + cd['tries'] += 1 + cd['last_run'] = time.time() + setattr(function, 'cooldowns', cooldowns) + return function(self, user, *args, **kwargs) + if time.time() - cd['last_run'] > timeout: + cd['tries'] = 1 + cd['last_run'] = time.time() + setattr(function, 'cooldowns', cooldowns) + return function(self, user, *args, **kwargs) + return failure(self, user, *args, **kwargs) + return wrapper + return do_cooldown + + class DiscordClient(discord.Client): def __init__(self, config, logger, commands): @@ -16,7 +38,6 @@ class DiscordClient(discord.Client): (re.compile(r'^(?P!|\?)(bella(gram|pics)|insta(gram|bella))$'), self._do_bellagram), (re.compile(r'^(?P!|\?)yt\s+(?P")?(?P.+)(?(q)")$'), self._do_yt), ] - self.cooldowns = {p.pattern: {} for p, _ in self.supported_commands} super(DiscordClient, self).__init__() async def start_(self): @@ -27,28 +48,18 @@ class DiscordClient(discord.Client): self.logger.info('Logged in as {0}'.format(self.user.name)) async def on_message(self, message): - def check_cooldown(pattern, user): - retries = self.config['Discord'].getint('cooldown_retries') - seconds = self.config['Discord'].getint('cooldown_seconds') - if user not in self.cooldowns[pattern]: - self.cooldowns[pattern][user] = dict(tries=0, time=0) - cooldown = self.cooldowns[pattern][user] - if cooldown['tries'] >= retries and time.time() - cooldown['time'] <= seconds: - return False - cooldown['tries'] = 1 if cooldown['tries'] >= retries else cooldown['tries'] + 1 - cooldown['time'] = time.time() - return True for pattern, action in self.supported_commands: m = pattern.match(message.content) if m: - if check_cooldown(pattern.pattern, message.author.id): - await action(message, **m.groupdict()) - else: - await self.send_message(message.channel, - 'Sorry {0}, you have to wait a while before running ' - 'the same command again'.format(message.author.mention)) + await action(message.author.id, message, **m.groupdict()) + + async def _cooldown_failure(self, user, message, **kwargs): + await self.send_message(message.channel, + 'Sorry {0}, you have to wait a while before running ' + 'the same command again'.format(message.author.mention)) - async def _do_bellagram(self, message, prefix, **kwargs): + @cooldown(retries=2, timeout=5*60, failure=_cooldown_failure) + async def _do_bellagram(self, user, message, prefix, **kwargs): try: bellagram = self.commands.bellagram() except CommandError as e: @@ -63,7 +74,8 @@ class DiscordClient(discord.Client): elif prefix == '?': await self.send_message(message.channel, bellagram['url']) - async def _do_yt(self, message, query, prefix, **kwargs): + @cooldown(retries=3, timeout=5*60, failure=_cooldown_failure) + async def _do_yt(self, user, message, query, prefix, **kwargs): try: result = self.commands.query_youtube(query) except CommandError as e: diff --git a/settings.cfg.example b/settings.cfg.example index f2c9db1..d7e7614 100644 --- a/settings.cfg.example +++ b/settings.cfg.example @@ -6,9 +6,6 @@ channels = lilialil [Discord] token = __DISCORD_TOKEN__ -# 5 minutes cooldown after 2 retries -cooldown_retries = 2 -cooldown_seconds = 300 [Twitch] api_url = https://api.twitch.tv/v5