From 3bcff314117bd91722d54ad7b1e6533ba077fab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Forr=C3=B3?= Date: Thu, 14 Mar 2019 13:37:16 +0100 Subject: [PATCH] Add single-sub-roll support, disable nested groups --- commands.py | 2 +- services/roll20.py | 169 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 143 insertions(+), 28 deletions(-) diff --git a/commands.py b/commands.py index 94e8151..3b37c5a 100644 --- a/commands.py +++ b/commands.py @@ -195,7 +195,7 @@ class Commands(object): try: result = Roll20.execute(formula) except Roll20Error: - raise CommandError('failed to interpret the formula') + raise CommandError('failed to interpret or execute the formula') else: return result diff --git a/services/roll20.py b/services/roll20.py index 5abd8bc..809c752 100644 --- a/services/roll20.py +++ b/services/roll20.py @@ -4,8 +4,10 @@ import parsy # TODO: functions -# TODO: dice matching -# FIXME: handle infinite loops +# TODO: nested groups + + +ROLL_LIMIT = 1000 def num(x): @@ -27,13 +29,89 @@ class Group(object): ' '.join([repr(x) for x in self.items]), self.keep, self.drop, self.succ, self.fail) def __str__(self): - kept = self._kept() - return '**{{** {0} **}}**'.format(' + '.join( - [str(x) if i in kept else '~~*{0}*~~'.format(x) for i, x in enumerate(self.items)])) + calculated = self._calculated() + kept = self.kept() + result = [] + for i, x in enumerate(self.items): + if i in kept: + if len(self.items) > 1 and self.succ: + if self.succ(calculated[i]): + result.append('__{0}__'.format(x)) + elif self.fail and self.fail(calculated[i]): + result.append('*{0}*'.format(x)) + else: + result.append(str(x)) + else: + result.append(str(x)) + else: + result.append('~~*{0}*~~'.format(x)) + return '**{{** {0} **}}**'.format(' + '.join(result)) - # TODO: handle single-sub-rolls - len(self.items) == 1 + def _subrolls(self, tree): + def traverse(node, subrolls): + try: + traverse(node.left, subrolls) + traverse(node.right, subrolls) + except AttributeError: + try: + node.result + subrolls.append(node) + except AttributeError: + return + subrolls = [] + traverse(tree, subrolls) + return subrolls + + def _make_expression(self, tree): + def exp(x): + def copy(node): + try: + return Operation(node.op, node.func, copy(node.left), copy(node.right)) + except AttributeError: + try: + node.result + return num(x) + except AttributeError: + return num(node) + new_tree = copy(tree) + try: + return new_tree.calc() + except AttributeError: + return num(new_tree) + return exp + + def _update_subrolls(self, subrolls): + results = [(r, i, j, s) for i, s in enumerate(subrolls) for j, r in enumerate(s.result) if j in s.kept()] + results = sorted(results, key=lambda x: (x[0], len(subrolls) - x[1], len(x[3].result) - x[2])) + if self.keep: + if self.keep[1]: + results = results[:round(self.keep[0])] + else: + results = results[-round(self.keep[0]):] + if self.drop: + if self.drop[1]: + results = results[:-round(self.drop[0])] + else: + results = results[round(self.drop[0]):] + for subroll in subrolls: + subroll.group_kept = [] + for _, _, i, subroll in results: + subroll.group_kept.append(i) def _calculated(self): + # FIXME: nested groups + if len(self.items) == 1: + subrolls = self._subrolls(self.items[0]) + self._update_subrolls(subrolls) + if len(subrolls) == 1: + exp = self._make_expression(self.items[0]) + if self.succ: + subrolls[0].succ = lambda x: self.succ(exp(x)) + if self.fail: + subrolls[0].fail = lambda x: self.fail(exp(x)) + return [subrolls[0].calc()] + elif self.succ: + raise RuntimeError('Multiple rolls in a single subroll are not allowed!') result = [] for item in self.items: try: @@ -42,7 +120,9 @@ class Group(object): result.append(num(item)) return result - def _kept(self): + def kept(self): + if len(self.items) == 1: + return [0] calculated = self._calculated() result = sorted(enumerate(calculated), key=lambda x: (x[1], len(calculated) - x[0])) if self.keep: @@ -59,16 +139,17 @@ class Group(object): def filtered(self): calculated = self._calculated() - kept = self._kept() + kept = self.kept() return [x for i, x in enumerate(calculated) if i in kept] def calc(self): filtered = self.filtered() - if self.succ: - result = len([x for x in filtered if self.succ(x)]) - if self.fail: - result -= len([x for x in filtered if self.fail(x)]) - return result + if len(self.items) > 1: + if self.succ: + result = len([x for x in filtered if self.succ(x)]) + if self.fail: + result -= len([x for x in filtered if self.fail(x)]) + return result return sum(filtered) @@ -83,7 +164,7 @@ class Operation(object): return '<{0} {1} {2}>'.format(self.op, repr(self.left), repr(self.right)) def __str__(self): - # FIXME: drop unneeded parentheses + # FIXME: get rid of unneeded parentheses return '( {0} {1} {2} )'.format(str(self.left), self.op.replace('*', '\\*'), str(self.right)) def calc(self): @@ -106,17 +187,33 @@ class Roll(object): self.drop = None self.succ = None self.fail = None + self.group_kept = None def __repr__(self): return ''.format( self.result, self.label, self.keep, self.drop, self.succ, self.fail) def __str__(self): - kept = self._kept() - return '**(** {0} **)**'.format(' + '.join( - [str(x) if i in kept else '~~*{0}*~~'.format(x) for i, x in enumerate(self.result)])) + kept = self.kept() + result = [] + for i, x in enumerate(self.result): + if i in kept: + if self.succ: + if self.succ(x): + result.append('__{0}__'.format(x)) + elif self.fail and self.fail(x): + result.append('*{0}*'.format(x)) + else: + result.append(str(x)) + else: + result.append(str(x)) + else: + result.append('~~*{0}*~~'.format(x)) + return '**(** {0} **)**'.format(' + '.join(result)) - def _kept(self): + def kept(self): + if self.group_kept is not None: + return self.group_kept result = sorted(enumerate(self.result), key=lambda x: (x[1], len(self.result) - x[0])) if self.keep: if self.keep[1]: @@ -131,7 +228,7 @@ class Roll(object): return list(list(zip(*sorted(result)))[0]) def filtered(self): - kept = self._kept() + kept = self.kept() return [x for i, x in enumerate(self.result) if i in kept] def calc(self): @@ -200,14 +297,16 @@ class Parser(object): @parsy.generate def group(): yield parsy.match_item('{') - result = yield group_simple + #result = yield group_simple + result = yield expression_additive result = Group([result]) while True: end = yield parsy.match_item('}') | parsy.success('') if end: break yield parsy.match_item(',') - other = yield group_simple + #other = yield group_simple + other = yield expression_additive result.items.append(other) return result @@ -348,10 +447,14 @@ class Parser(object): result = result[:] i = 0 while i < len(result): + cnt = 0 while condition(result[i], dice[0]): result[i] = random.choice(dice) if modifier == 'ro': break + cnt += 1 + if cnt > ROLL_LIMIT: + raise RuntimeError('Roll limit reached!') i += 1 return result, dice return 4, modify @@ -366,8 +469,12 @@ class Parser(object): i = 0 while i < len(result): sub = [result[i]] + cnt = 0 while condition(sub[-1], dice[-1]): sub.append(random.choice(dice)) + cnt += 1 + if cnt > ROLL_LIMIT: + raise RuntimeError('Roll limit reached!') result[i+1 : i+1] = [x - 1 for x in sub][1:] i += len(sub) return result, dice @@ -383,8 +490,12 @@ class Parser(object): i = 0 while i < len(result): sub = [result[i]] + cnt = 0 while condition(sub[-1], dice[-1]): sub.append(random.choice(dice)) + cnt += 1 + if cnt > ROLL_LIMIT: + raise RuntimeError('Roll limit reached!') result[i] = sum(sub) i += 1 return result, dice @@ -400,8 +511,12 @@ class Parser(object): i = 0 while i < len(result): sub = [result[i]] + cnt = 0 while condition(sub[-1], dice[-1]): sub.append(random.choice(dice)) + cnt += 1 + if cnt > ROLL_LIMIT: + raise RuntimeError('Roll limit reached!') result[i: i+1] = sub i += len(sub) return result, dice @@ -507,10 +622,10 @@ class Roll20(object): try: tokens = Parser.tokenize(formula) result = Parser.parse(tokens) - except parsy.ParseError as e: + try: + calculated = result.calc() + except AttributeError: + calculated = num(result) + return '{0} = __**{1}**__'.format(str(result), calculated) + except (parsy.ParseError, RuntimeError) as e: raise Roll20Error(str(e)) - try: - calculated = result.calc() - except AttributeError: - calculated = num(result) - return '{0} = __**{1}**__'.format(str(result), calculated)