Refactoring and making rtv OAuth-compliant
This commit is contained in:
@@ -11,16 +11,16 @@ import configparser
|
||||
|
||||
from . import config
|
||||
from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError
|
||||
from .curses_helpers import curses_session
|
||||
from .curses_helpers import curses_session, LoadScreen
|
||||
from .submission import SubmissionPage
|
||||
from .subreddit import SubredditPage
|
||||
from .docs import *
|
||||
from .oauth import load_oauth_config, read_setting, write_setting, authorize
|
||||
from .oauth import OAuthTool
|
||||
from .__version__ import __version__
|
||||
|
||||
__all__ = []
|
||||
|
||||
def load_config():
|
||||
def open_config():
|
||||
"""
|
||||
Search for a configuration file at the location ~/.rtv and attempt to load
|
||||
saved settings for things like the username and password.
|
||||
@@ -41,6 +41,15 @@ def load_config():
|
||||
config.read(config_path)
|
||||
break
|
||||
|
||||
return config
|
||||
|
||||
def load_rtv_config():
|
||||
"""
|
||||
Attempt to load saved settings for things like the username and password.
|
||||
"""
|
||||
|
||||
config = open_config()
|
||||
|
||||
defaults = {}
|
||||
if config.has_section('rtv'):
|
||||
defaults = dict(config.items('rtv'))
|
||||
@@ -50,6 +59,18 @@ def load_config():
|
||||
|
||||
return defaults
|
||||
|
||||
def load_oauth_config():
|
||||
"""
|
||||
Attempt to load saved OAuth settings
|
||||
"""
|
||||
|
||||
config = open_config()
|
||||
|
||||
defaults = {}
|
||||
if config.has_section('oauth'):
|
||||
defaults = dict(config.items('oauth'))
|
||||
|
||||
return defaults
|
||||
|
||||
def command_line():
|
||||
|
||||
@@ -69,6 +90,13 @@ def command_line():
|
||||
group.add_argument('-u', dest='username', help='reddit username')
|
||||
group.add_argument('-p', dest='password', help='reddit password')
|
||||
|
||||
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
|
||||
oauth_group.add_argument('--client-id', dest='client_id', help='OAuth app ID')
|
||||
oauth_group.add_argument('--redurect-uri', dest='redirect_uri', help='OAuth app redirect URI')
|
||||
oauth_group.add_argument('--auth-token', dest='authorization_token', help='OAuth authorization token')
|
||||
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
|
||||
oauth_group.add_argument('--scope', dest='scope', help='OAuth app scope')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
@@ -81,7 +109,8 @@ def main():
|
||||
locale.setlocale(locale.LC_ALL, '')
|
||||
|
||||
args = command_line()
|
||||
local_config = load_config()
|
||||
local_rtv_config = load_rtv_config()
|
||||
local_oauth_config = load_oauth_config()
|
||||
|
||||
# set the terminal title
|
||||
title = 'rtv {0}'.format(__version__)
|
||||
@@ -92,10 +121,14 @@ def main():
|
||||
|
||||
# Fill in empty arguments with config file values. Paramaters explicitly
|
||||
# typed on the command line will take priority over config file params.
|
||||
for key, val in local_config.items():
|
||||
for key, val in local_rtv_config.items():
|
||||
if getattr(args, key, None) is None:
|
||||
setattr(args, key, val)
|
||||
|
||||
for k, v in local_oauth_config.items():
|
||||
if getattr(args, k, None) is None:
|
||||
setattr(args, k, v)
|
||||
|
||||
config.unicode = (not args.ascii)
|
||||
|
||||
# Squelch SSL warnings for Ubuntu
|
||||
@@ -107,34 +140,19 @@ def main():
|
||||
print('Connecting...')
|
||||
reddit = praw.Reddit(user_agent=AGENT)
|
||||
reddit.config.decode_html_entities = False
|
||||
if read_setting(key="authorization_token") is None:
|
||||
print('Hello OAuth login helper!')
|
||||
authorize(reddit)
|
||||
else:
|
||||
oauth_config = load_oauth_config()
|
||||
oauth_data = {}
|
||||
if oauth_config.has_section('oauth'):
|
||||
oauth_data = dict(oauth_config.items('oauth'))
|
||||
|
||||
reddit.set_oauth_app_info(oauth_data['client_id'],
|
||||
oauth_data['client_secret'],
|
||||
oauth_data['redirect_uri'])
|
||||
|
||||
reddit.set_access_credentials(scope=set(oauth_data['scope'].split('-')),
|
||||
access_token=oauth_data['authorization_token'],
|
||||
refresh_token=oauth_data['refresh_token'])
|
||||
"""if args.username:
|
||||
# PRAW will prompt for password if it is None
|
||||
reddit.login(args.username, args.password)"""
|
||||
with curses_session() as stdscr:
|
||||
oauth = OAuthTool(reddit, stdscr, LoadScreen(stdscr))
|
||||
oauth.authorize()
|
||||
if args.link:
|
||||
page = SubmissionPage(stdscr, reddit, url=args.link)
|
||||
page = SubmissionPage(stdscr, reddit, oauth, url=args.link)
|
||||
page.loop()
|
||||
subreddit = args.subreddit or 'front'
|
||||
page = SubredditPage(stdscr, reddit, subreddit)
|
||||
page = SubredditPage(stdscr, reddit, oauth, subreddit)
|
||||
page.loop()
|
||||
except praw.errors.InvalidUserPass:
|
||||
print('Invalid password for username: {}'.format(args.username))
|
||||
except praw.errors.OAuthAppRequired:
|
||||
print('Invalid OAuth app config parameters')
|
||||
except requests.ConnectionError:
|
||||
print('Connection timeout')
|
||||
except requests.HTTPError:
|
||||
@@ -150,7 +168,6 @@ def main():
|
||||
pass
|
||||
finally:
|
||||
# Ensure sockets are closed to prevent a ResourceWarning
|
||||
print(reddit.is_oauth_session())
|
||||
reddit.handler.http.close()
|
||||
|
||||
sys.exit(main())
|
||||
|
||||
@@ -2,4 +2,13 @@
|
||||
Global configuration settings
|
||||
"""
|
||||
|
||||
unicode = True
|
||||
unicode = True
|
||||
|
||||
"""
|
||||
OAuth settings
|
||||
"""
|
||||
|
||||
oauth_client_id = 'nxoobnwO7mCP5A'
|
||||
oauth_client_secret = 'praw_gapfill'
|
||||
oauth_redirect_uri = 'https://rtv.theo-piboubes.fr/auth'
|
||||
oauth_scope = 'edit-history-identity-mysubreddits-privatemessages-read-report-save-submit-subscribe-vote'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from .__version__ import __version__
|
||||
|
||||
__all__ = ['AGENT', 'SUMMARY', 'AUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE',
|
||||
__all__ = ['AGENT', 'SUMMARY', 'AUTH', 'OAUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE',
|
||||
'SUBMISSION_FILE', 'COMMENT_EDIT_FILE']
|
||||
|
||||
AGENT = """\
|
||||
@@ -17,6 +17,11 @@ Authenticating is required to vote and leave comments. If only a username is
|
||||
given, the program will display a secure prompt to enter a password.
|
||||
"""
|
||||
|
||||
OAUTH = """\
|
||||
Authentication is now done by OAuth, since PRAW will stop supporting login with
|
||||
username and password soon.
|
||||
"""
|
||||
|
||||
CONTROLS = """
|
||||
Controls
|
||||
--------
|
||||
|
||||
184
rtv/oauth.py
184
rtv/oauth.py
@@ -1,87 +1,145 @@
|
||||
import configparser
|
||||
import curses
|
||||
import logging
|
||||
import os
|
||||
import webbrowser
|
||||
import time
|
||||
import uuid
|
||||
import webbrowser
|
||||
|
||||
__all__ = []
|
||||
import praw
|
||||
|
||||
def get_config_file_path():
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config'))
|
||||
config_paths = [
|
||||
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||
os.path.join(HOME, '.rtv')
|
||||
]
|
||||
from . import config
|
||||
from .curses_helpers import show_notification, prompt_input
|
||||
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
break
|
||||
__all__ = ['token_validity', 'OAuthTool']
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
return config_path
|
||||
token_validity = 3540
|
||||
|
||||
def load_oauth_config():
|
||||
config = configparser.ConfigParser()
|
||||
config_path = get_config_file_path()
|
||||
config.read(config_path)
|
||||
class OAuthTool(object):
|
||||
|
||||
return config
|
||||
def __init__(self, reddit, stdscr=None, loader=None,
|
||||
client_id=None, redirect_uri=None, scope=None):
|
||||
self.reddit = reddit
|
||||
self.stdscr = stdscr
|
||||
self.loader = loader
|
||||
|
||||
def read_setting(key, section='oauth'):
|
||||
config = load_oauth_config()
|
||||
self.config = configparser.ConfigParser()
|
||||
self.config_fp = None
|
||||
|
||||
try:
|
||||
setting = config[section][key]
|
||||
except KeyError:
|
||||
return None
|
||||
self.client_id = client_id or config.oauth_client_id
|
||||
# Comply with PRAW's desperate need for client secret
|
||||
self.client_secret = config.oauth_client_secret
|
||||
self.redirect_uri = redirect_uri or config.oauth_redirect_uri
|
||||
|
||||
return setting
|
||||
self.scope = scope or config.oauth_scope.split('-')
|
||||
|
||||
def write_setting(key, value, section='oauth'):
|
||||
config = load_oauth_config()
|
||||
self.access_info = {}
|
||||
|
||||
config[section][key] = value
|
||||
with open(config_path, 'w') as cfg_file:
|
||||
config.write(cfg_file)
|
||||
self.token_expiration = 0
|
||||
|
||||
def authorize(reddit):
|
||||
config = load_oauth_config()
|
||||
def get_config_fp(self):
|
||||
HOME = os.path.expanduser('~')
|
||||
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
|
||||
os.path.join(HOME, '.config'))
|
||||
|
||||
settings = {}
|
||||
if config.has_section('oauth'):
|
||||
settings = dict(config.items('oauth'))
|
||||
config_paths = [
|
||||
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
|
||||
os.path.join(HOME, '.rtv')
|
||||
]
|
||||
|
||||
scopes = ["edit", "history", "identity", "mysubreddits", "privatemessages", "read", "report", "save", "submit", "subscribe", "vote"]
|
||||
# get the first existing config file
|
||||
for config_path in config_paths:
|
||||
if os.path.exists(config_path):
|
||||
break
|
||||
|
||||
reddit.set_oauth_app_info(settings['client_id'],
|
||||
settings['client_secret'],
|
||||
settings['redirect_uri'])
|
||||
return config_path
|
||||
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
def open_config(self, update=False):
|
||||
if self.config_fp is None:
|
||||
self.config_fp = self.get_config_fp()
|
||||
|
||||
permission_ask_page_link = reddit.get_authorize_url(str(hex_uuid), scope=scopes, refreshable=True)
|
||||
input("You will now be redirected to your web browser. Press Enter to continue.")
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
if update:
|
||||
self.config.read(self.config_fp)
|
||||
|
||||
print("After allowing rtv app access, you will land on a page giving you a state and a code string. Please enter them here.")
|
||||
final_state = input("State : ")
|
||||
final_code = input("Code : ")
|
||||
def save_config(self):
|
||||
self.open_config()
|
||||
with open(self.config_fp, 'w') as cfg:
|
||||
self.config.write(cfg)
|
||||
|
||||
# Check if UUID matches obtained state
|
||||
# (if not, authorization process is compromised, and I'm giving up)
|
||||
if hex_uuid == final_state:
|
||||
print("Obtained state matches UUID")
|
||||
else:
|
||||
print("Obtained state does not match UUID, stopping.")
|
||||
return
|
||||
def set_token_expiration(self):
|
||||
self.token_expiration = time.time() + token_validity
|
||||
|
||||
# Get access information (authorization token)
|
||||
info = reddit.get_access_information(final_code)
|
||||
config['oauth']['authorization_token'] = info['access_token']
|
||||
config['oauth']['refresh_token'] = info['refresh_token']
|
||||
config['oauth']['scope'] = '-'.join(info['scope'])
|
||||
def token_expired(self):
|
||||
return time.time() > self.token_expiration
|
||||
|
||||
config_path = get_config_file_path()
|
||||
with open(config_path, 'w') as cfg_file:
|
||||
config.write(cfg_file)
|
||||
def refresh(self, force=False):
|
||||
if self.token_expired() or force:
|
||||
try:
|
||||
with self.loader(message='Refreshing token'):
|
||||
new_access_info = self.reddit.refresh_access_information(
|
||||
self.config['oauth']['refresh_token'])
|
||||
self.access_info = new_access_info
|
||||
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.save_config()
|
||||
|
||||
def authorize(self):
|
||||
self.reddit.set_oauth_app_info(self.client_id,
|
||||
self.client_secret,
|
||||
self.redirect_uri)
|
||||
|
||||
self.open_config(update=True)
|
||||
# If no previous OAuth data found, starting from scratch
|
||||
if 'oauth' not in self.config or 'access_token' not in self.config['oauth']:
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
|
||||
permission_ask_page_link = self.reddit.get_authorize_url(str(hex_uuid),
|
||||
scope=self.scope, refreshable=True)
|
||||
|
||||
webbrowser.open(permission_ask_page_link)
|
||||
show_notification(self.stdscr, ['Access prompt opened in web browser'])
|
||||
|
||||
final_state = prompt_input(self.stdscr, 'State: ')
|
||||
final_code = prompt_input(self.stdscr, 'Code: ')
|
||||
|
||||
if not final_state or not final_code:
|
||||
curses.flash()
|
||||
return
|
||||
|
||||
# Check if UUID matches obtained state
|
||||
# (if not, authorization process is compromised, and I'm giving up)
|
||||
if hex_uuid != final_state:
|
||||
show_notification(self.stdscr, ['UUID mismatch, stopping.'])
|
||||
return
|
||||
|
||||
# Get access information (tokens and scopes)
|
||||
self.access_info = self.reddit.get_access_information(final_code)
|
||||
|
||||
try:
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.set_access_credentials(
|
||||
scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
if 'oauth' not in self.config:
|
||||
self.config['oauth'] = {}
|
||||
|
||||
self.config['oauth']['access_token'] = self.access_info['access_token']
|
||||
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
|
||||
self.save_config()
|
||||
# Otherwise, fetch new access token
|
||||
else:
|
||||
self.refresh(force=True)
|
||||
|
||||
42
rtv/page.py
42
rtv/page.py
@@ -12,6 +12,7 @@ from .helpers import open_editor
|
||||
from .curses_helpers import (Color, show_notification, show_help, prompt_input,
|
||||
add_line)
|
||||
from .docs import COMMENT_EDIT_FILE, SUBMISSION_FILE
|
||||
from .oauth import OAuthTool
|
||||
|
||||
__all__ = ['Navigator', 'BaseController', 'BasePage']
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -244,11 +245,12 @@ class BasePage(object):
|
||||
MIN_HEIGHT = 10
|
||||
MIN_WIDTH = 20
|
||||
|
||||
def __init__(self, stdscr, reddit, content, **kwargs):
|
||||
def __init__(self, stdscr, reddit, content, oauth, **kwargs):
|
||||
|
||||
self.stdscr = stdscr
|
||||
self.reddit = reddit
|
||||
self.content = content
|
||||
self.oauth = oauth
|
||||
self.nav = Navigator(self.content.get, **kwargs)
|
||||
|
||||
self._header_window = None
|
||||
@@ -312,6 +314,9 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('a')
|
||||
def upvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -327,6 +332,9 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('z')
|
||||
def downvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -348,23 +356,11 @@ class BasePage(object):
|
||||
account.
|
||||
"""
|
||||
|
||||
if self.reddit.is_logged_in():
|
||||
self.logout()
|
||||
if self.reddit.is_oauth_session():
|
||||
self.reddit.clear_authentication()
|
||||
return
|
||||
|
||||
username = prompt_input(self.stdscr, 'Enter username:')
|
||||
password = prompt_input(self.stdscr, 'Enter password:', hide=True)
|
||||
if not username or not password:
|
||||
curses.flash()
|
||||
return
|
||||
|
||||
try:
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.login(username, password)
|
||||
except praw.errors.InvalidUserPass:
|
||||
show_notification(self.stdscr, ['Invalid user/pass'])
|
||||
else:
|
||||
show_notification(self.stdscr, ['Welcome {}'.format(username)])
|
||||
self.oauth.authorize()
|
||||
|
||||
@BaseController.register('d')
|
||||
def delete(self):
|
||||
@@ -372,10 +368,13 @@ class BasePage(object):
|
||||
Delete a submission or comment.
|
||||
"""
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -400,10 +399,13 @@ class BasePage(object):
|
||||
Edit a submission or comment.
|
||||
"""
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -437,6 +439,10 @@ class BasePage(object):
|
||||
"""
|
||||
Checks the inbox for unread messages and displays a notification.
|
||||
"""
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
inbox = len(list(self.reddit.get_unread(limit=1)))
|
||||
try:
|
||||
if inbox > 0:
|
||||
|
||||
@@ -20,10 +20,11 @@ class SubmissionController(BaseController):
|
||||
|
||||
class SubmissionPage(BasePage):
|
||||
|
||||
def __init__(self, stdscr, reddit, url=None, submission=None):
|
||||
def __init__(self, stdscr, reddit, oauth, url=None, submission=None):
|
||||
|
||||
self.controller = SubmissionController(self)
|
||||
self.loader = LoadScreen(stdscr)
|
||||
self.oauth = oauth
|
||||
if url:
|
||||
content = SubmissionContent.from_url(reddit, url, self.loader)
|
||||
elif submission:
|
||||
@@ -32,7 +33,7 @@ class SubmissionPage(BasePage):
|
||||
raise ValueError('Must specify url or submission')
|
||||
|
||||
super(SubmissionPage, self).__init__(stdscr, reddit,
|
||||
content, page_index=-1)
|
||||
content, oauth, page_index=-1)
|
||||
|
||||
def loop(self):
|
||||
"Main control loop"
|
||||
@@ -88,10 +89,13 @@ class SubmissionPage(BasePage):
|
||||
selected comment.
|
||||
"""
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data['type'] == 'Submission':
|
||||
content = data['text']
|
||||
@@ -127,6 +131,9 @@ class SubmissionPage(BasePage):
|
||||
def delete_comment(self):
|
||||
"Delete a comment as long as it is not the current submission"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
if self.nav.absolute_index != -1:
|
||||
self.delete()
|
||||
else:
|
||||
|
||||
@@ -33,13 +33,14 @@ class SubredditController(BaseController):
|
||||
|
||||
class SubredditPage(BasePage):
|
||||
|
||||
def __init__(self, stdscr, reddit, name):
|
||||
def __init__(self, stdscr, reddit, oauth, name):
|
||||
|
||||
self.controller = SubredditController(self)
|
||||
self.loader = LoadScreen(stdscr)
|
||||
self.oauth = oauth
|
||||
|
||||
content = SubredditContent.from_name(reddit, name, self.loader)
|
||||
super(SubredditPage, self).__init__(stdscr, reddit, content)
|
||||
super(SubredditPage, self).__init__(stdscr, reddit, content, oauth)
|
||||
|
||||
def loop(self):
|
||||
"Main control loop"
|
||||
@@ -53,6 +54,9 @@ class SubredditPage(BasePage):
|
||||
def refresh_content(self, name=None, order=None):
|
||||
"Re-download all submissions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
name = name or self.content.name
|
||||
order = order or self.content.order
|
||||
|
||||
@@ -104,7 +108,7 @@ class SubredditPage(BasePage):
|
||||
"Select the current submission to view posts"
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
page = SubmissionPage(self.stdscr, self.reddit, url=data['permalink'])
|
||||
page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=data['permalink'])
|
||||
page.loop()
|
||||
if data['url_type'] == 'selfpost':
|
||||
global history
|
||||
@@ -119,7 +123,7 @@ class SubredditPage(BasePage):
|
||||
global history
|
||||
history.add(url)
|
||||
if data['url_type'] in ['x-post', 'selfpost']:
|
||||
page = SubmissionPage(self.stdscr, self.reddit, url=url)
|
||||
page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=url)
|
||||
page.loop()
|
||||
else:
|
||||
open_browser(url)
|
||||
@@ -128,10 +132,13 @@ class SubredditPage(BasePage):
|
||||
def post_submission(self):
|
||||
"Post a new submission to the given subreddit"
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Strips the subreddit to just the name
|
||||
# Make sure it is a valid subreddit for submission
|
||||
subreddit = self.reddit.get_subreddit(self.content.name)
|
||||
@@ -161,7 +168,7 @@ class SubredditPage(BasePage):
|
||||
time.sleep(2.0)
|
||||
# Open the newly created post
|
||||
s.catch = False
|
||||
page = SubmissionPage(self.stdscr, self.reddit, submission=post)
|
||||
page = SubmissionPage(self.stdscr, self.reddit, self.oauth, submission=post)
|
||||
page.loop()
|
||||
self.refresh_content()
|
||||
|
||||
@@ -169,12 +176,15 @@ class SubredditPage(BasePage):
|
||||
def open_subscriptions(self):
|
||||
"Open user subscriptions page"
|
||||
|
||||
if not self.reddit.is_logged_in() and not self.reddit.is_oauth_session():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Open subscriptions page
|
||||
page = SubscriptionPage(self.stdscr, self.reddit)
|
||||
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
|
||||
page.loop()
|
||||
|
||||
# When user has chosen a subreddit in the subscriptions list,
|
||||
|
||||
@@ -15,14 +15,15 @@ class SubscriptionController(BaseController):
|
||||
|
||||
class SubscriptionPage(BasePage):
|
||||
|
||||
def __init__(self, stdscr, reddit):
|
||||
def __init__(self, stdscr, reddit, oauth):
|
||||
|
||||
self.controller = SubscriptionController(self)
|
||||
self.loader = LoadScreen(stdscr)
|
||||
self.oauth = oauth
|
||||
self.selected_subreddit_data = None
|
||||
|
||||
content = SubscriptionContent.from_user(reddit, self.loader)
|
||||
super(SubscriptionPage, self).__init__(stdscr, reddit, content)
|
||||
super(SubscriptionPage, self).__init__(stdscr, reddit, content, oauth)
|
||||
|
||||
def loop(self):
|
||||
"Main control loop"
|
||||
@@ -37,6 +38,9 @@ class SubscriptionPage(BasePage):
|
||||
def refresh_content(self):
|
||||
"Re-download all subscriptions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
|
||||
self.nav = Navigator(self.content.get)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user