diff --git a/rtv/__main__.py b/rtv/__main__.py index 8d6e561..b084de4 100644 --- a/rtv/__main__.py +++ b/rtv/__main__.py @@ -7,7 +7,7 @@ import logging import requests import praw import praw.errors -from six.moves import configparser +import configparser from . import config from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError @@ -15,6 +15,7 @@ from .curses_helpers import curses_session from .submission import SubmissionPage from .subreddit import SubredditPage from .docs import * +from .oauth import load_oauth_config, read_setting, write_setting, authorize from .__version__ import __version__ __all__ = [] @@ -106,9 +107,25 @@ def main(): print('Connecting...') reddit = praw.Reddit(user_agent=AGENT) reddit.config.decode_html_entities = False - if args.username: + if read_setting(key="authorization_token") is None: + print('Hello OAuth login helper!') + authorize(reddit) + else: + oauth_config = load_oauth_config() + oauth_data = {} + if oauth_config.has_section('oauth'): + oauth_data = dict(oauth_config.items('oauth')) + + reddit.set_oauth_app_info(oauth_data['client_id'], + oauth_data['client_secret'], + oauth_data['redirect_uri']) + + reddit.set_access_credentials(scope=set(oauth_data['scope'].split('-')), + access_token=oauth_data['authorization_token'], + refresh_token=oauth_data['refresh_token']) + """if args.username: # PRAW will prompt for password if it is None - reddit.login(args.username, args.password) + reddit.login(args.username, args.password)""" with curses_session() as stdscr: if args.link: page = SubmissionPage(stdscr, reddit, url=args.link) @@ -133,6 +150,7 @@ def main(): pass finally: # Ensure sockets are closed to prevent a ResourceWarning + print(reddit.is_oauth_session()) reddit.handler.http.close() sys.exit(main()) diff --git a/rtv/oauth.py b/rtv/oauth.py new file mode 100644 index 0000000..868b69f --- /dev/null +++ b/rtv/oauth.py @@ -0,0 +1,87 @@ +import configparser +import os +import webbrowser +import uuid + +__all__ = [] + +def get_config_file_path(): + HOME = os.path.expanduser('~') + XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config')) + config_paths = [ + os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'), + os.path.join(HOME, '.rtv') + ] + + # get the first existing config file + for config_path in config_paths: + if os.path.exists(config_path): + break + + return config_path + +def load_oauth_config(): + config = configparser.ConfigParser() + config_path = get_config_file_path() + config.read(config_path) + + return config + +def read_setting(key, section='oauth'): + config = load_oauth_config() + + try: + setting = config[section][key] + except KeyError: + return None + + return setting + +def write_setting(key, value, section='oauth'): + config = load_oauth_config() + + config[section][key] = value + with open(config_path, 'w') as cfg_file: + config.write(cfg_file) + +def authorize(reddit): + config = load_oauth_config() + + settings = {} + if config.has_section('oauth'): + settings = dict(config.items('oauth')) + + scopes = ["edit", "history", "identity", "mysubreddits", "privatemessages", "read", "report", "save", "submit", "subscribe", "vote"] + + reddit.set_oauth_app_info(settings['client_id'], + settings['client_secret'], + settings['redirect_uri']) + + # Generate a random UUID + hex_uuid = uuid.uuid4().hex + + permission_ask_page_link = reddit.get_authorize_url(str(hex_uuid), scope=scopes, refreshable=True) + input("You will now be redirected to your web browser. Press Enter to continue.") + webbrowser.open(permission_ask_page_link) + + print("After allowing rtv app access, you will land on a page giving you a state and a code string. Please enter them here.") + final_state = input("State : ") + final_code = input("Code : ") + + # Check if UUID matches obtained state + # (if not, authorization process is compromised, and I'm giving up) + if hex_uuid == final_state: + print("Obtained state matches UUID") + else: + print("Obtained state does not match UUID, stopping.") + return + + # Get access information (authorization token) + info = reddit.get_access_information(final_code) + config['oauth']['authorization_token'] = info['access_token'] + config['oauth']['refresh_token'] = info['refresh_token'] + config['oauth']['scope'] = '-'.join(info['scope']) + + config_path = get_config_file_path() + with open(config_path, 'w') as cfg_file: + config.write(cfg_file) diff --git a/rtv/subreddit.py b/rtv/subreddit.py index a869a3b..7fc44e5 100644 --- a/rtv/subreddit.py +++ b/rtv/subreddit.py @@ -169,7 +169,7 @@ class SubredditPage(BasePage): def open_subscriptions(self): "Open user subscriptions page" - if not self.reddit.is_logged_in(): + if not self.reddit.is_logged_in() and not self.reddit.is_oauth_session(): show_notification(self.stdscr, ['Not logged in']) return