From 01a39db9d6f9b59546e0cc98ef634c1e9ad44384 Mon Sep 17 00:00:00 2001 From: io Date: Wed, 16 Jun 2021 03:49:34 +0000 Subject: [PATCH] rewrite reply.py too --- functions.py | 1 - pleroma.py | 91 ++++++++++++++++++++++++++----- reply.py | 148 +++++++++++++++++++++++++-------------------------- 3 files changed, 151 insertions(+), 89 deletions(-) diff --git a/functions.py b/functions.py index 0f90445..d0a549e 100755 --- a/functions.py +++ b/functions.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # SPDX-License-Identifier: EUPL-1.2 import re diff --git a/pleroma.py b/pleroma.py index 3c4308f..01d806d 100644 --- a/pleroma.py +++ b/pleroma.py @@ -1,24 +1,28 @@ # SPDX-License-Identifier: EUPL-1.2 import sys +import json import aiohttp - -USER_AGENT = ( - 'pleroma-ebooks (https://github.com/ioistired/pleroma-ebooks); ' - 'aiohttp/{aiohttp.__version__}; ' - 'python/{py_version}' -) +from http import HTTPStatus def http_session_factory(headers={}): - return aiohttp.ClientSession( - headers={'User-Agent': USER_AGENT, **headers}, - raise_for_status=True, + py_version = '.'.join(map(str, sys.version_info)) + user_agent = ( + 'pleroma-ebooks (https://github.com/ioistired/pleroma-ebooks); ' + 'aiohttp/{aiohttp.__version__}; ' + 'python/{py_version}' ) + return aiohttp.ClientSession( + headers={'User-Agent': user_agent, **headers}, + ) + +class BadRequest(Exception): + pass class Pleroma: def __init__(self, *, api_base_url, access_token): self.api_base_url = api_base_url.rstrip('/') - py_version = '.'.join(map(str, sys.version_info)) + self.access_token = access_token self._session = http_session_factory({'Authorization': 'Bearer ' + access_token}) self._logged_in_id = None @@ -31,6 +35,9 @@ class Pleroma: async def request(self, method, path, **kwargs): async with self._session.request(method, self.api_base_url + path, **kwargs) as resp: + if resp.status == HTTPStatus.BAD_REQUEST: + raise BadRequest((await resp.json())['error']) + resp.raise_for_status() return await resp.json() async def verify_credentials(self): @@ -47,12 +54,21 @@ class Pleroma: account_id = account_id or await self._get_logged_in_id() return await self.request('GET', f'/api/v1/accounts/{account_id}/following') + @staticmethod + def _unpack_id(obj): + if isinstance(obj, dict) and 'id' in obj: + return obj['id'] + return obj + + async def status_context(self, id): + id = self._unpack_id(id) + return await self.request('GET', f'/api/v1/statuses/{id}/context') + async def post(self, content, *, in_reply_to_id=None, cw=None, visibility=None): if visibility not in {None, 'private', 'public', 'unlisted', 'direct'}: raise ValueError('invalid visibility', visibility) - if isinstance(in_reply_to_id, dict) and 'id' in in_reply_to_id: - in_reply_to_id = in_reply_to_id['id'] + in_reply_to_id = self._unpack_id(in_reply_to_id) data = dict(status=content, in_reply_to_id=in_reply_to_id) if visibility is not None: @@ -73,8 +89,55 @@ class Pleroma: status = ''.join('@' + x + ' ' for x in mentioned_accounts.values()) + content - visibility = to_status['visibility'] - if cw is None and 'spoiler_text' in to_status: + visibility = 'unlisted' if to_status['visibility'] == 'public' else to_status['visibility'] + if cw is None and 'spoiler_text' in to_status and to_status['spoiler_text']: cw = 're: ' + to_status['spoiler_text'] return await self.post(content, in_reply_to_id=to_status['id'], cw=cw, visibility=visibility) + + async def favorite(self, id): + id = self._unpack_id(id) + return await self.request('POST', f'/api/v1/statuses/{id}/favourite') + + async def unfavorite(self, id): + id = self._unpack_id(id) + return await self.request('POST', f'/api/v1/statuses/{id}/unfavourite') + + async def react(self, id, reaction): + id = self._unpack_id(id) + return await self.request('PUT', f'/api/v1/pleroma/statuses/{id}/reactions/{reaction}') + + async def remove_reaction(self, id, reaction): + id = self._unpack_id(id) + return await self.request('DELETE', f'/api/v1/pleroma/statuses/{id}/reactions/{reaction}') + + async def pin(self, id): + id = self._unpack_id(id) + return await self.request('POST', f'/api/v1/statuses/{id}/pin') + + async def unpin(self, id): + id = self._unpack_id(id) + return await self.request('POST', f'/api/v1/statuses/{id}/unpin') + + async def stream(self, stream_name, *, target_event_type=None): + async with self._session.ws_connect( + self.api_base_url + f'/api/v1/streaming?stream={stream_name}&access_token={self.access_token}' + ) as ws: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.TEXT: + event = msg.json() + # the only event type that doesn't define `payload` is `filters_changed` + if event['event'] == 'filters_changed': + yield event + elif target_event_type is None or event['event'] == target_event_type: + # don't ask me why the payload is also JSON encoded smh + yield json.loads(event['payload']) + + async def stream_notifications(self): + async for notif in self.stream('user:notification', target_event_type='notification'): + yield notif + + async def stream_mentions(self): + async for notif in self.stream_notifications(): + if notif['type'] == 'mention': + yield notif diff --git a/reply.py b/reply.py index dd694dd..927a98c 100755 --- a/reply.py +++ b/reply.py @@ -1,89 +1,89 @@ #!/usr/bin/env python3 -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# SPDX-License-Identifier: EUPL-1.2 -import mastodon -import re, json, argparse +import re +import anyio +import pleroma import functions +import contextlib -parser = argparse.ArgumentParser(description='Reply service. Leave running in the background.') -parser.add_argument( - '-c', '--cfg', dest='cfg', default='config.json', nargs='?', - help="Specify a custom location for config.json.") +def parse_args(): + return functions.arg_parser_factory(description='Reply service. Leave running in the background.').parse_args() -args = parser.parse_args() +class ReplyBot: + def __init__(self, cfg): + self.cfg = cfg + self.pleroma = pleroma.Pleroma(access_token=cfg['access_token'], api_base_url=cfg['site']) -cfg = json.load(open(args.cfg, 'r')) + async def run(self): + async with self.pleroma as self.pleroma: + self.me = (await self.pleroma.me())['id'] + self.follows = frozenset(user['id'] for user in await self.pleroma.following(self.me)) + async for notification in self.pleroma.stream_mentions(): + await self.process_notification(notification) -client = mastodon.Mastodon( - client_id=cfg['client']['id'], - client_secret=cfg['client']['secret'], - access_token=cfg['secret'], - api_base_url=cfg['site']) + async def process_notification(self, notification): + acct = "@" + notification['account']['acct'] # get the account's @ + post_id = notification['status']['id'] + context = await self.pleroma.status_context(post_id) + # check if we've already been participating in this thread + if self.check_thread_length(context): + return -def extract_toot(toot): - text = functions.extract_toot(toot) - text = re.sub(r"^@[^@]+@[^ ]+\s*", r"", text) # remove the initial mention - text = text.lower() # treat text as lowercase for easier keyword matching (if this bot uses it) - return text + content = self.extract_toot(notification['status']['content']) + if content in {'pin', 'unpin'}: + await self.process_command(context, notification, content) + else: + await self.reply(notification) + def check_thread_length(self, context) -> bool: + """return whether the thread is too long to reply to""" + posts = 0 + for post in context['ancestors']: + if post['account']['id'] == self.me: + posts += 1 + if posts >= self.cfg['max_thread_length']: + return True -class ReplyListener(mastodon.StreamListener): - def on_notification(self, notification): # listen for notifications - if notification['type'] == 'mention': # if we're mentioned: - acct = "@" + notification['account']['acct'] # get the account's @ - post_id = notification['status']['id'] + return False - # check if we've already been participating in this thread - try: - context = client.status_context(post_id) - except: - print("failed to fetch thread context") - return - me = client.account_verify_credentials()['id'] - posts = 0 - for post in context['ancestors']: - if post['account']['id'] == me: - pin = post["id"] # Only used if pin is called, but easier to call here - posts += 1 - if posts >= cfg['max_thread_length']: - # stop replying - print("didn't reply (max_thread_length exceeded)") - return + async def process_command(self, context, notification, command): + post_id = notification['status']['id'] + if notification['account']['id'] not in self.follows: # this user is unauthorized + await self.pleroma.react(post_id, '❌') + return - mention = extract_toot(notification['status']['content']) - if (mention == "pin") or (mention == "unpin"): # check for keywords - print("Found pin/unpin") - # get a list of people the bot is following - validusers = client.account_following(me) - for user in validusers: - if user["id"] == notification["account"]["id"]: # user is #valid - print("User is valid") - visibility = notification['status']['visibility'] - if visibility == "public": - visibility = "unlisted" - if mention == "pin": - print("pin received, pinning") - client.status_pin(pin) - client.status_post("Toot pinned!", post_id, visibility=visibility, spoiler_text=cfg['cw']) - else: - print("unpin received, unpinning") - client.status_post("Toot unpinned!", post_id, visibility=visibility, spoiler_text=cfg['cw']) - client.status_unpin(pin) - else: - print("User is not valid") - else: - toot = functions.make_toot(cfg) # generate a toot - toot = acct + " " + toot # prepend the @ - print(acct + " says " + mention) # logging - visibility = notification['status']['visibility'] - if visibility == "public": - visibility = "unlisted" - client.status_post(toot, post_id, visibility=visibility, spoiler_text=cfg['cw']) # send toost - print("replied with " + toot) # logging + # find the post the user is talking about + for post in context['ancestors']: + if post['id'] == notification['status']['in_reply_to_id']: + target_post_id = post['id'] + try: + await (self.pleroma.pin if command == 'pin' else self.pleroma.unpin)(target_post_id) + except pleroma.BadRequest as exc: + async with anyio.create_task_group() as tg: + tg.start_soon(self.pleroma.react, post_id, '❌') + tg.start_soon(self.pleroma.reply, notification['status'], 'Error: ' + exc.args[0]) + else: + await self.pleroma.react(post_id, '✅') -rl = ReplyListener() -client.stream_user(rl) # go! + async def reply(self, notification): + toot = functions.make_toot(self.cfg) # generate a toot + await self.pleroma.reply(notification['status'], toot, cw=self.cfg['cw']) + + @staticmethod + def extract_toot(toot): + text = functions.extract_toot(toot) + text = re.sub(r"^@\S+\s", r"", text) # remove the initial mention + text = text.lower() # treat text as lowercase for easier keyword matching (if this bot uses it) + return text + +async def amain(): + args = parse_args() + cfg = functions.load_config(args.cfg) + await ReplyBot(cfg).run() + +if __name__ == '__main__': + with contextlib.suppress(KeyboardInterrupt): + anyio.run(amain)