Refactoring and making rtv OAuth-compliant
This commit is contained in:
184
rtv/oauth.py
184
rtv/oauth.py
@@ -1,87 +1,145 @@
|
||||
import configparser
|
||||
import curses
|
||||
import logging
|
||||
import os
|
||||
import webbrowser
|
||||
import time
|
||||
import uuid
|
||||
import webbrowser
|
||||
|
||||
__all__ = []
|
||||
import praw
|
||||
|
||||
def get_config_file_path():
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config'))
|
||||
config_paths = [
|
||||
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||
os.path.join(HOME, '.rtv')
|
||||
]
|
||||
from . import config
|
||||
from .curses_helpers import show_notification, prompt_input
|
||||
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
break
|
||||
__all__ = ['token_validity', 'OAuthTool']
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
return config_path
|
||||
token_validity = 3540
|
||||
|
||||
def load_oauth_config():
|
||||
config = configparser.ConfigParser()
|
||||
config_path = get_config_file_path()
|
||||
config.read(config_path)
|
||||
class OAuthTool(object):
|
||||
|
||||
return config
|
||||
def __init__(self, reddit, stdscr=None, loader=None,
|
||||
client_id=None, redirect_uri=None, scope=None):
|
||||
self.reddit = reddit
|
||||
self.stdscr = stdscr
|
||||
self.loader = loader
|
||||
|
||||
def read_setting(key, section='oauth'):
|
||||
config = load_oauth_config()
|
||||
self.config = configparser.ConfigParser()
|
||||
self.config_fp = None
|
||||
|
||||
try:
|
||||
setting = config[section][key]
|
||||
except KeyError:
|
||||
return None
|
||||
self.client_id = client_id or config.oauth_client_id
|
||||
# Comply with PRAW's desperate need for client secret
|
||||
self.client_secret = config.oauth_client_secret
|
||||
self.redirect_uri = redirect_uri or config.oauth_redirect_uri
|
||||
|
||||
return setting
|
||||
self.scope = scope or config.oauth_scope.split('-')
|
||||
|
||||
def write_setting(key, value, section='oauth'):
|
||||
config = load_oauth_config()
|
||||
self.access_info = {}
|
||||
|
||||
config[section][key] = value
|
||||
with open(config_path, 'w') as cfg_file:
|
||||
config.write(cfg_file)
|
||||
self.token_expiration = 0
|
||||
|
||||
def authorize(reddit):
|
||||
config = load_oauth_config()
|
||||
def get_config_fp(self):
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
|
||||
os.path.join(HOME, '.config'))
|
||||
|
||||
settings = {}
|
||||
if config.has_section('oauth'):
|
||||
settings = dict(config.items('oauth'))
|
||||
config_paths = [
|
||||
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||
os.path.join(HOME, '.rtv')
|
||||
]
|
||||
|
||||
scopes = ["edit", "history", "identity", "mysubreddits", "privatemessages", "read", "report", "save", "submit", "subscribe", "vote"]
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
break
|
||||
|
||||
reddit.set_oauth_app_info(settings['client_id'],
|
||||
settings['client_secret'],
|
||||
settings['redirect_uri'])
|
||||
return config_path
|
||||
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
def open_config(self, update=False):
|
||||
if self.config_fp is None:
|
||||
self.config_fp = self.get_config_fp()
|
||||
|
||||
permission_ask_page_link = reddit.get_authorize_url(str(hex_uuid), scope=scopes, refreshable=True)
|
||||
input("You will now be redirected to your web browser. Press Enter to continue.")
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
if update:
|
||||
self.config.read(self.config_fp)
|
||||
|
||||
print("After allowing rtv app access, you will land on a page giving you a state and a code string. Please enter them here.")
|
||||
final_state = input("State : ")
|
||||
final_code = input("Code : ")
|
||||
def save_config(self):
|
||||
self.open_config()
|
||||
with open(self.config_fp, 'w') as cfg:
|
||||
self.config.write(cfg)
|
||||
|
||||
# Check if UUID matches obtained state
|
||||
# (if not, authorization process is compromised, and I'm giving up)
|
||||
if hex_uuid == final_state:
|
||||
print("Obtained state matches UUID")
|
||||
else:
|
||||
print("Obtained state does not match UUID, stopping.")
|
||||
return
|
||||
def set_token_expiration(self):
|
||||
self.token_expiration = time.time() + token_validity
|
||||
|
||||
# Get access information (authorization token)
|
||||
info = reddit.get_access_information(final_code)
|
||||
config['oauth']['authorization_token'] = info['access_token']
|
||||
config['oauth']['refresh_token'] = info['refresh_token']
|
||||
config['oauth']['scope'] = '-'.join(info['scope'])
|
||||
def token_expired(self):
|
||||
return time.time() > self.token_expiration
|
||||
|
||||
config_path = get_config_file_path()
|
||||
with open(config_path, 'w') as cfg_file:
|
||||
config.write(cfg_file)
|
||||
def refresh(self, force=False):
|
||||
if self.token_expired() or force:
|
||||
try:
|
||||
with self.loader(message='Refreshing token'):
|
||||
new_access_info = self.reddit.refresh_access_information(
|
||||
self.config['oauth']['refresh_token'])
|
||||
self.access_info = new_access_info
|
||||
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.save_config()
|
||||
|
||||
def authorize(self):
|
||||
self.reddit.set_oauth_app_info(self.client_id,
|
||||
self.client_secret,
|
||||
self.redirect_uri)
|
||||
|
||||
self.open_config(update=True)
|
||||
# If no previous OAuth data found, starting from scratch
|
||||
if 'oauth' not in self.config or 'access_token' not in self.config['oauth']:
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
|
||||
permission_ask_page_link = self.reddit.get_authorize_url(str(hex_uuid),
|
||||
scope=self.scope, refreshable=True)
|
||||
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
show_notification(self.stdscr, ['Access prompt opened in web browser'])
|
||||
|
||||
final_state = prompt_input(self.stdscr, 'State: ')
|
||||
final_code = prompt_input(self.stdscr, 'Code: ')
|
||||
|
||||
if not final_state or not final_code:
|
||||
curses.flash()
|
||||
return
|
||||
|
||||
# Check if UUID matches obtained state
|
||||
# (if not, authorization process is compromised, and I'm giving up)
|
||||
if hex_uuid != final_state:
|
||||
show_notification(self.stdscr, ['UUID mismatch, stopping.'])
|
||||
return
|
||||
|
||||
# Get access information (tokens and scopes)
|
||||
self.access_info = self.reddit.get_access_information(final_code)
|
||||
|
||||
try:
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.set_access_credentials(
|
||||
scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
if 'oauth' not in self.config:
|
||||
self.config['oauth'] = {}
|
||||
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.save_config()
|
||||
# Otherwise, fetch new access token
|
||||
else:
|
||||
self.refresh(force=True)
|
||||
|
||||
Reference in New Issue
Block a user