From efed781fa160806888077aba435bd489d50b5141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Piboub=C3=A8s?= Date: Mon, 17 Aug 2015 00:36:18 +0200 Subject: [PATCH] Refactoring and making rtv OAuth-compliant --- rtv/__main__.py | 71 ++++++++++------- rtv/config.py | 11 ++- rtv/docs.py | 7 +- rtv/oauth.py | 184 ++++++++++++++++++++++++++++--------------- rtv/page.py | 42 +++++----- rtv/submission.py | 13 ++- rtv/subreddit.py | 26 ++++-- rtv/subscriptions.py | 8 +- 8 files changed, 239 insertions(+), 123 deletions(-) diff --git a/rtv/__main__.py b/rtv/__main__.py index b084de4..043244e 100644 --- a/rtv/__main__.py +++ b/rtv/__main__.py @@ -11,16 +11,16 @@ import configparser from . import config from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError -from .curses_helpers import curses_session +from .curses_helpers import curses_session, LoadScreen from .submission import SubmissionPage from .subreddit import SubredditPage from .docs import * -from .oauth import load_oauth_config, read_setting, write_setting, authorize +from .oauth import OAuthTool from .__version__ import __version__ __all__ = [] -def load_config(): +def open_config(): """ Search for a configuration file at the location ~/.rtv and attempt to load saved settings for things like the username and password. @@ -41,6 +41,15 @@ def load_config(): config.read(config_path) break + return config + +def load_rtv_config(): + """ + Attempt to load saved settings for things like the username and password. + """ + + config = open_config() + defaults = {} if config.has_section('rtv'): defaults = dict(config.items('rtv')) @@ -50,6 +59,18 @@ def load_config(): return defaults +def load_oauth_config(): + """ + Attempt to load saved OAuth settings + """ + + config = open_config() + + defaults = {} + if config.has_section('oauth'): + defaults = dict(config.items('oauth')) + + return defaults def command_line(): @@ -69,6 +90,13 @@ def command_line(): group.add_argument('-u', dest='username', help='reddit username') group.add_argument('-p', dest='password', help='reddit password') + oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH) + oauth_group.add_argument('--client-id', dest='client_id', help='OAuth app ID') + oauth_group.add_argument('--redurect-uri', dest='redirect_uri', help='OAuth app redirect URI') + oauth_group.add_argument('--auth-token', dest='authorization_token', help='OAuth authorization token') + oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token') + oauth_group.add_argument('--scope', dest='scope', help='OAuth app scope') + args = parser.parse_args() return args @@ -81,7 +109,8 @@ def main(): locale.setlocale(locale.LC_ALL, '') args = command_line() - local_config = load_config() + local_rtv_config = load_rtv_config() + local_oauth_config = load_oauth_config() # set the terminal title title = 'rtv {0}'.format(__version__) @@ -92,10 +121,14 @@ def main(): # Fill in empty arguments with config file values. Paramaters explicitly # typed on the command line will take priority over config file params. - for key, val in local_config.items(): + for key, val in local_rtv_config.items(): if getattr(args, key, None) is None: setattr(args, key, val) + for k, v in local_oauth_config.items(): + if getattr(args, k, None) is None: + setattr(args, k, v) + config.unicode = (not args.ascii) # Squelch SSL warnings for Ubuntu @@ -107,34 +140,19 @@ def main(): print('Connecting...') reddit = praw.Reddit(user_agent=AGENT) reddit.config.decode_html_entities = False - if read_setting(key="authorization_token") is None: - print('Hello OAuth login helper!') - authorize(reddit) - else: - oauth_config = load_oauth_config() - oauth_data = {} - if oauth_config.has_section('oauth'): - oauth_data = dict(oauth_config.items('oauth')) - - reddit.set_oauth_app_info(oauth_data['client_id'], - oauth_data['client_secret'], - oauth_data['redirect_uri']) - - reddit.set_access_credentials(scope=set(oauth_data['scope'].split('-')), - access_token=oauth_data['authorization_token'], - refresh_token=oauth_data['refresh_token']) - """if args.username: - # PRAW will prompt for password if it is None - reddit.login(args.username, args.password)""" with curses_session() as stdscr: + oauth = OAuthTool(reddit, stdscr, LoadScreen(stdscr)) + oauth.authorize() if args.link: - page = SubmissionPage(stdscr, reddit, url=args.link) + page = SubmissionPage(stdscr, reddit, oauth, url=args.link) page.loop() subreddit = args.subreddit or 'front' - page = SubredditPage(stdscr, reddit, subreddit) + page = SubredditPage(stdscr, reddit, oauth, subreddit) page.loop() except praw.errors.InvalidUserPass: print('Invalid password for username: {}'.format(args.username)) + except praw.errors.OAuthAppRequired: + print('Invalid OAuth app config parameters') except requests.ConnectionError: print('Connection timeout') except requests.HTTPError: @@ -150,7 +168,6 @@ def main(): pass finally: # Ensure sockets are closed to prevent a ResourceWarning - print(reddit.is_oauth_session()) reddit.handler.http.close() sys.exit(main()) diff --git a/rtv/config.py b/rtv/config.py index c59a16d..e7761af 100644 --- a/rtv/config.py +++ b/rtv/config.py @@ -2,4 +2,13 @@ Global configuration settings """ -unicode = True \ No newline at end of file +unicode = True + +""" +OAuth settings +""" + +oauth_client_id = 'nxoobnwO7mCP5A' +oauth_client_secret = 'praw_gapfill' +oauth_redirect_uri = 'https://rtv.theo-piboubes.fr/auth' +oauth_scope = 'edit-history-identity-mysubreddits-privatemessages-read-report-save-submit-subscribe-vote' diff --git a/rtv/docs.py b/rtv/docs.py index 8775728..cefd7b0 100644 --- a/rtv/docs.py +++ b/rtv/docs.py @@ -1,6 +1,6 @@ from .__version__ import __version__ -__all__ = ['AGENT', 'SUMMARY', 'AUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE', +__all__ = ['AGENT', 'SUMMARY', 'AUTH', 'OAUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE', 'SUBMISSION_FILE', 'COMMENT_EDIT_FILE'] AGENT = """\ @@ -17,6 +17,11 @@ Authenticating is required to vote and leave comments. If only a username is given, the program will display a secure prompt to enter a password. """ +OAUTH = """\ +Authentication is now done by OAuth, since PRAW will stop supporting login with +username and password soon. +""" + CONTROLS = """ Controls -------- diff --git a/rtv/oauth.py b/rtv/oauth.py index 868b69f..c5482b2 100644 --- a/rtv/oauth.py +++ b/rtv/oauth.py @@ -1,87 +1,145 @@ import configparser +import curses +import logging import os -import webbrowser +import time import uuid +import webbrowser -__all__ = [] +import praw -def get_config_file_path(): - HOME = os.path.expanduser('~') - XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config')) - config_paths = [ - os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'), - os.path.join(HOME, '.rtv') - ] +from . import config +from .curses_helpers import show_notification, prompt_input - # get the first existing config file - for config_path in config_paths: - if os.path.exists(config_path): - break +__all__ = ['token_validity', 'OAuthTool'] +_logger = logging.getLogger(__name__) - return config_path +token_validity = 3540 -def load_oauth_config(): - config = configparser.ConfigParser() - config_path = get_config_file_path() - config.read(config_path) +class OAuthTool(object): - return config + def __init__(self, reddit, stdscr=None, loader=None, + client_id=None, redirect_uri=None, scope=None): + self.reddit = reddit + self.stdscr = stdscr + self.loader = loader -def read_setting(key, section='oauth'): - config = load_oauth_config() + self.config = configparser.ConfigParser() + self.config_fp = None - try: - setting = config[section][key] - except KeyError: - return None + self.client_id = client_id or config.oauth_client_id + # Comply with PRAW's desperate need for client secret + self.client_secret = config.oauth_client_secret + self.redirect_uri = redirect_uri or config.oauth_redirect_uri - return setting + self.scope = scope or config.oauth_scope.split('-') -def write_setting(key, value, section='oauth'): - config = load_oauth_config() + self.access_info = {} - config[section][key] = value - with open(config_path, 'w') as cfg_file: - config.write(cfg_file) + self.token_expiration = 0 -def authorize(reddit): - config = load_oauth_config() + def get_config_fp(self): + HOME = os.path.expanduser('~') + XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', + os.path.join(HOME, '.config')) - settings = {} - if config.has_section('oauth'): - settings = dict(config.items('oauth')) + config_paths = [ + os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'), + os.path.join(HOME, '.rtv') + ] - scopes = ["edit", "history", "identity", "mysubreddits", "privatemessages", "read", "report", "save", "submit", "subscribe", "vote"] + # get the first existing config file + for config_path in config_paths: + if os.path.exists(config_path): + break - reddit.set_oauth_app_info(settings['client_id'], - settings['client_secret'], - settings['redirect_uri']) + return config_path - # Generate a random UUID - hex_uuid = uuid.uuid4().hex + def open_config(self, update=False): + if self.config_fp is None: + self.config_fp = self.get_config_fp() - permission_ask_page_link = reddit.get_authorize_url(str(hex_uuid), scope=scopes, refreshable=True) - input("You will now be redirected to your web browser. Press Enter to continue.") - webbrowser.open(permission_ask_page_link) + if update: + self.config.read(self.config_fp) - print("After allowing rtv app access, you will land on a page giving you a state and a code string. Please enter them here.") - final_state = input("State : ") - final_code = input("Code : ") + def save_config(self): + self.open_config() + with open(self.config_fp, 'w') as cfg: + self.config.write(cfg) - # Check if UUID matches obtained state - # (if not, authorization process is compromised, and I'm giving up) - if hex_uuid == final_state: - print("Obtained state matches UUID") - else: - print("Obtained state does not match UUID, stopping.") - return + def set_token_expiration(self): + self.token_expiration = time.time() + token_validity - # Get access information (authorization token) - info = reddit.get_access_information(final_code) - config['oauth']['authorization_token'] = info['access_token'] - config['oauth']['refresh_token'] = info['refresh_token'] - config['oauth']['scope'] = '-'.join(info['scope']) + def token_expired(self): + return time.time() > self.token_expiration - config_path = get_config_file_path() - with open(config_path, 'w') as cfg_file: - config.write(cfg_file) + def refresh(self, force=False): + if self.token_expired() or force: + try: + with self.loader(message='Refreshing token'): + new_access_info = self.reddit.refresh_access_information( + self.config['oauth']['refresh_token']) + self.access_info = new_access_info + self.reddit.set_access_credentials(scope=set(self.access_info['scope']), + access_token=self.access_info['access_token'], + refresh_token=self.access_info['refresh_token']) + self.set_token_expiration() + except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e: + show_notification(self.stdscr, ['Invalid OAuth data']) + else: + self.config['oauth']['access_token'] = self.access_info['access_token'] + self.config['oauth']['refresh_token'] = self.access_info['refresh_token'] + self.save_config() + + def authorize(self): + self.reddit.set_oauth_app_info(self.client_id, + self.client_secret, + self.redirect_uri) + + self.open_config(update=True) + # If no previous OAuth data found, starting from scratch + if 'oauth' not in self.config or 'access_token' not in self.config['oauth']: + # Generate a random UUID + hex_uuid = uuid.uuid4().hex + + permission_ask_page_link = self.reddit.get_authorize_url(str(hex_uuid), + scope=self.scope, refreshable=True) + + webbrowser.open(permission_ask_page_link) + show_notification(self.stdscr, ['Access prompt opened in web browser']) + + final_state = prompt_input(self.stdscr, 'State: ') + final_code = prompt_input(self.stdscr, 'Code: ') + + if not final_state or not final_code: + curses.flash() + return + + # Check if UUID matches obtained state + # (if not, authorization process is compromised, and I'm giving up) + if hex_uuid != final_state: + show_notification(self.stdscr, ['UUID mismatch, stopping.']) + return + + # Get access information (tokens and scopes) + self.access_info = self.reddit.get_access_information(final_code) + + try: + with self.loader(message='Logging in'): + self.reddit.set_access_credentials( + scope=set(self.access_info['scope']), + access_token=self.access_info['access_token'], + refresh_token=self.access_info['refresh_token']) + self.set_token_expiration() + except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e: + show_notification(self.stdscr, ['Invalid OAuth data']) + else: + if 'oauth' not in self.config: + self.config['oauth'] = {} + + self.config['oauth']['access_token'] = self.access_info['access_token'] + self.config['oauth']['refresh_token'] = self.access_info['refresh_token'] + self.save_config() + # Otherwise, fetch new access token + else: + self.refresh(force=True) diff --git a/rtv/page.py b/rtv/page.py index ef075d6..8f13565 100644 --- a/rtv/page.py +++ b/rtv/page.py @@ -12,6 +12,7 @@ from .helpers import open_editor from .curses_helpers import (Color, show_notification, show_help, prompt_input, add_line) from .docs import COMMENT_EDIT_FILE, SUBMISSION_FILE +from .oauth import OAuthTool __all__ = ['Navigator', 'BaseController', 'BasePage'] _logger = logging.getLogger(__name__) @@ -244,11 +245,12 @@ class BasePage(object): MIN_HEIGHT = 10 MIN_WIDTH = 20 - def __init__(self, stdscr, reddit, content, **kwargs): + def __init__(self, stdscr, reddit, content, oauth, **kwargs): self.stdscr = stdscr self.reddit = reddit self.content = content + self.oauth = oauth self.nav = Navigator(self.content.get, **kwargs) self._header_window = None @@ -312,6 +314,9 @@ class BasePage(object): @BaseController.register('a') def upvote(self): + # Refresh access token if expired + self.oauth.refresh() + data = self.content.get(self.nav.absolute_index) try: if 'likes' not in data: @@ -327,6 +332,9 @@ class BasePage(object): @BaseController.register('z') def downvote(self): + # Refresh access token if expired + self.oauth.refresh() + data = self.content.get(self.nav.absolute_index) try: if 'likes' not in data: @@ -348,23 +356,11 @@ class BasePage(object): account. """ - if self.reddit.is_logged_in(): - self.logout() + if self.reddit.is_oauth_session(): + self.reddit.clear_authentication() return - username = prompt_input(self.stdscr, 'Enter username:') - password = prompt_input(self.stdscr, 'Enter password:', hide=True) - if not username or not password: - curses.flash() - return - - try: - with self.loader(message='Logging in'): - self.reddit.login(username, password) - except praw.errors.InvalidUserPass: - show_notification(self.stdscr, ['Invalid user/pass']) - else: - show_notification(self.stdscr, ['Welcome {}'.format(username)]) + self.oauth.authorize() @BaseController.register('d') def delete(self): @@ -372,10 +368,13 @@ class BasePage(object): Delete a submission or comment. """ - if not self.reddit.is_logged_in(): + if not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return + # Refresh access token if expired + self.oauth.refresh() + data = self.content.get(self.nav.absolute_index) if data.get('author') != self.reddit.user.name: curses.flash() @@ -400,10 +399,13 @@ class BasePage(object): Edit a submission or comment. """ - if not self.reddit.is_logged_in(): + if not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return + # Refresh access token if expired + self.oauth.refresh() + data = self.content.get(self.nav.absolute_index) if data.get('author') != self.reddit.user.name: curses.flash() @@ -437,6 +439,10 @@ class BasePage(object): """ Checks the inbox for unread messages and displays a notification. """ + + # Refresh access token if expired + self.oauth.refresh() + inbox = len(list(self.reddit.get_unread(limit=1))) try: if inbox > 0: diff --git a/rtv/submission.py b/rtv/submission.py index 8731f0a..1086ce6 100644 --- a/rtv/submission.py +++ b/rtv/submission.py @@ -20,10 +20,11 @@ class SubmissionController(BaseController): class SubmissionPage(BasePage): - def __init__(self, stdscr, reddit, url=None, submission=None): + def __init__(self, stdscr, reddit, oauth, url=None, submission=None): self.controller = SubmissionController(self) self.loader = LoadScreen(stdscr) + self.oauth = oauth if url: content = SubmissionContent.from_url(reddit, url, self.loader) elif submission: @@ -32,7 +33,7 @@ class SubmissionPage(BasePage): raise ValueError('Must specify url or submission') super(SubmissionPage, self).__init__(stdscr, reddit, - content, page_index=-1) + content, oauth, page_index=-1) def loop(self): "Main control loop" @@ -88,10 +89,13 @@ class SubmissionPage(BasePage): selected comment. """ - if not self.reddit.is_logged_in(): + if not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return + # Refresh access token if expired + self.oauth.refresh() + data = self.content.get(self.nav.absolute_index) if data['type'] == 'Submission': content = data['text'] @@ -127,6 +131,9 @@ class SubmissionPage(BasePage): def delete_comment(self): "Delete a comment as long as it is not the current submission" + # Refresh access token if expired + self.oauth.refresh() + if self.nav.absolute_index != -1: self.delete() else: diff --git a/rtv/subreddit.py b/rtv/subreddit.py index 7fc44e5..fcd1f07 100644 --- a/rtv/subreddit.py +++ b/rtv/subreddit.py @@ -33,13 +33,14 @@ class SubredditController(BaseController): class SubredditPage(BasePage): - def __init__(self, stdscr, reddit, name): + def __init__(self, stdscr, reddit, oauth, name): self.controller = SubredditController(self) self.loader = LoadScreen(stdscr) + self.oauth = oauth content = SubredditContent.from_name(reddit, name, self.loader) - super(SubredditPage, self).__init__(stdscr, reddit, content) + super(SubredditPage, self).__init__(stdscr, reddit, content, oauth) def loop(self): "Main control loop" @@ -53,6 +54,9 @@ class SubredditPage(BasePage): def refresh_content(self, name=None, order=None): "Re-download all submissions and reset the page index" + # Refresh access token if expired + self.oauth.refresh() + name = name or self.content.name order = order or self.content.order @@ -104,7 +108,7 @@ class SubredditPage(BasePage): "Select the current submission to view posts" data = self.content.get(self.nav.absolute_index) - page = SubmissionPage(self.stdscr, self.reddit, url=data['permalink']) + page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=data['permalink']) page.loop() if data['url_type'] == 'selfpost': global history @@ -119,7 +123,7 @@ class SubredditPage(BasePage): global history history.add(url) if data['url_type'] in ['x-post', 'selfpost']: - page = SubmissionPage(self.stdscr, self.reddit, url=url) + page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=url) page.loop() else: open_browser(url) @@ -128,10 +132,13 @@ class SubredditPage(BasePage): def post_submission(self): "Post a new submission to the given subreddit" - if not self.reddit.is_logged_in(): + if not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return + # Refresh access token if expired + self.oauth.refresh() + # Strips the subreddit to just the name # Make sure it is a valid subreddit for submission subreddit = self.reddit.get_subreddit(self.content.name) @@ -161,7 +168,7 @@ class SubredditPage(BasePage): time.sleep(2.0) # Open the newly created post s.catch = False - page = SubmissionPage(self.stdscr, self.reddit, submission=post) + page = SubmissionPage(self.stdscr, self.reddit, self.oauth, submission=post) page.loop() self.refresh_content() @@ -169,12 +176,15 @@ class SubredditPage(BasePage): def open_subscriptions(self): "Open user subscriptions page" - if not self.reddit.is_logged_in() and not self.reddit.is_oauth_session(): + if not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return + # Refresh access token if expired + self.oauth.refresh() + # Open subscriptions page - page = SubscriptionPage(self.stdscr, self.reddit) + page = SubscriptionPage(self.stdscr, self.reddit, self.oauth) page.loop() # When user has chosen a subreddit in the subscriptions list, diff --git a/rtv/subscriptions.py b/rtv/subscriptions.py index ae8981e..64e3a2a 100644 --- a/rtv/subscriptions.py +++ b/rtv/subscriptions.py @@ -15,14 +15,15 @@ class SubscriptionController(BaseController): class SubscriptionPage(BasePage): - def __init__(self, stdscr, reddit): + def __init__(self, stdscr, reddit, oauth): self.controller = SubscriptionController(self) self.loader = LoadScreen(stdscr) + self.oauth = oauth self.selected_subreddit_data = None content = SubscriptionContent.from_user(reddit, self.loader) - super(SubscriptionPage, self).__init__(stdscr, reddit, content) + super(SubscriptionPage, self).__init__(stdscr, reddit, content, oauth) def loop(self): "Main control loop" @@ -37,6 +38,9 @@ class SubscriptionPage(BasePage): def refresh_content(self): "Re-download all subscriptions and reset the page index" + # Refresh access token if expired + self.oauth.refresh() + self.content = SubscriptionContent.from_user(self.reddit, self.loader) self.nav = Navigator(self.content.get)