Merge branch 'master' of https://github.com/TheoPib/rtv into TheoPib-master

This commit is contained in:
Michael Lazar
2015-08-27 00:14:04 -07:00
14 changed files with 398 additions and 103 deletions

View File

@@ -1,2 +1,3 @@
include version.py include version.py
include CHANGELOG.rst CONTRIBUTORS.rst LICENSE include CHANGELOG.rst CONTRIBUTORS.rst LICENSE
include rtv/templates/*.html

View File

@@ -84,7 +84,7 @@ Once you are logged in your username will appear in the top-right corner of the
:``c``: Compose a new post or comment :``c``: Compose a new post or comment
:``e``: Edit an existing post or comment :``e``: Edit an existing post or comment
:``d``: Delete an existing post or comment :``d``: Delete an existing post or comment
:``s``: Open subscribed subreddits list :``s``: Open/close subscribed subreddits list
-------------- --------------
Subreddit Mode Subreddit Mode
@@ -152,14 +152,18 @@ RTV will read a configuration placed at ``~/.config/rtv/rtv.cfg`` (or ``$XDG_CON
Each line in the file will replace the corresponding default argument in the launch script. Each line in the file will replace the corresponding default argument in the launch script.
This can be used to avoid having to re-enter login credentials every time the program is launched. This can be used to avoid having to re-enter login credentials every time the program is launched.
Example config: The OAuth section contains a boolean to trigger auto-login (defaults to False).
When authenticated, two additional fields are written : **access_token** and **refresh_token**.
Those are basically like username and password : they are used to authenticate you on Reddit servers.
Example initial config:
.. code-block:: ini .. code-block:: ini
[rtv] [oauth]
username=MyUsername auto_login=False
password=MySecretPassword
[rtv]
# Log file location # Log file location
log=/tmp/rtv.log log=/tmp/rtv.log

View File

@@ -11,15 +11,35 @@ from six.moves import configparser
from . import config from . import config
from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError from .exceptions import SubmissionError, SubredditError, SubscriptionError, ProgramError
from .curses_helpers import curses_session from .curses_helpers import curses_session, LoadScreen
from .submission import SubmissionPage from .submission import SubmissionPage
from .subreddit import SubredditPage from .subreddit import SubredditPage
from .docs import * from .docs import *
from .oauth import OAuthTool
from .__version__ import __version__ from .__version__ import __version__
from tornado import ioloop
__all__ = [] __all__ = []
def load_config(): def get_config_fp():
HOME = os.path.expanduser('~')
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
os.path.join(HOME, '.config'))
config_paths = [
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
os.path.join(HOME, '.rtv')
]
# get the first existing config file
for config_path in config_paths:
if os.path.exists(config_path):
break
return config_path
def open_config():
""" """
Search for a configuration file at the location ~/.rtv and attempt to load Search for a configuration file at the location ~/.rtv and attempt to load
saved settings for things like the username and password. saved settings for things like the username and password.
@@ -27,18 +47,17 @@ def load_config():
config = configparser.ConfigParser() config = configparser.ConfigParser()
HOME = os.path.expanduser('~') config_path = get_config_fp()
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME', os.path.join(HOME, '.config')) config.read(config_path)
config_paths = [
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
os.path.join(HOME, '.rtv')
]
# read only the first existing config file return config
for config_path in config_paths:
if os.path.exists(config_path): def load_rtv_config():
config.read(config_path) """
break Attempt to load saved settings for things like the username and password.
"""
config = open_config()
defaults = {} defaults = {}
if config.has_section('rtv'): if config.has_section('rtv'):
@@ -49,6 +68,23 @@ def load_config():
return defaults return defaults
def load_oauth_config():
"""
Attempt to load saved OAuth settings
"""
config = open_config()
if config.has_section('oauth'):
defaults = dict(config.items('oauth'))
else:
# Populate OAuth section
config['oauth'] = {'auto_login': False}
with open(get_config_fp(), 'w') as cfg:
config.write(cfg)
defaults = dict(config.items('oauth'))
return defaults
def command_line(): def command_line():
@@ -68,6 +104,11 @@ def command_line():
group.add_argument('-u', dest='username', help='reddit username') group.add_argument('-u', dest='username', help='reddit username')
group.add_argument('-p', dest='password', help='reddit password') group.add_argument('-p', dest='password', help='reddit password')
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -80,7 +121,8 @@ def main():
locale.setlocale(locale.LC_ALL, '') locale.setlocale(locale.LC_ALL, '')
args = command_line() args = command_line()
local_config = load_config() local_rtv_config = load_rtv_config()
local_oauth_config = load_oauth_config()
# set the terminal title # set the terminal title
title = 'rtv {0}'.format(__version__) title = 'rtv {0}'.format(__version__)
@@ -91,10 +133,14 @@ def main():
# Fill in empty arguments with config file values. Paramaters explicitly # Fill in empty arguments with config file values. Paramaters explicitly
# typed on the command line will take priority over config file params. # typed on the command line will take priority over config file params.
for key, val in local_config.items(): for key, val in local_rtv_config.items():
if getattr(args, key, None) is None: if getattr(args, key, None) is None:
setattr(args, key, val) setattr(args, key, val)
for k, v in local_oauth_config.items():
if getattr(args, k, None) is None:
setattr(args, k, v)
config.unicode = (not args.ascii) config.unicode = (not args.ascii)
# Squelch SSL warnings for Ubuntu # Squelch SSL warnings for Ubuntu
@@ -106,18 +152,19 @@ def main():
print('Connecting...') print('Connecting...')
reddit = praw.Reddit(user_agent=AGENT) reddit = praw.Reddit(user_agent=AGENT)
reddit.config.decode_html_entities = False reddit.config.decode_html_entities = False
if args.username:
# PRAW will prompt for password if it is None
reddit.login(args.username, args.password)
with curses_session() as stdscr: with curses_session() as stdscr:
oauth = OAuthTool(reddit, stdscr, LoadScreen(stdscr))
if args.auto_login == 'True': # Ew!
oauth.authorize()
if args.link: if args.link:
page = SubmissionPage(stdscr, reddit, url=args.link) page = SubmissionPage(stdscr, reddit, oauth, url=args.link)
page.loop() page.loop()
subreddit = args.subreddit or 'front' subreddit = args.subreddit or 'front'
page = SubredditPage(stdscr, reddit, subreddit) page = SubredditPage(stdscr, reddit, oauth, subreddit)
page.loop() page.loop()
except praw.errors.InvalidUserPass: except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
print('Invalid password for username: {}'.format(args.username)) praw.errors.HTTPException) as e:
print('Invalid OAuth data')
except requests.ConnectionError: except requests.ConnectionError:
print('Connection timeout') print('Connection timeout')
except requests.HTTPError: except requests.HTTPError:
@@ -134,5 +181,7 @@ def main():
finally: finally:
# Ensure sockets are closed to prevent a ResourceWarning # Ensure sockets are closed to prevent a ResourceWarning
reddit.handler.http.close() reddit.handler.http.close()
# Explicitly close file descriptors opened by Tornado's IOLoop
ioloop.IOLoop.current().close(all_fds=True)
sys.exit(main()) sys.exit(main())

View File

@@ -2,4 +2,13 @@
Global configuration settings Global configuration settings
""" """
unicode = True unicode = True
"""
OAuth settings
"""
oauth_client_id = 'nxoobnwO7mCP5A'
oauth_client_secret = 'praw_gapfill'
oauth_redirect_uri = 'http://127.0.0.1:65000/auth'
oauth_scope = 'edit-history-identity-mysubreddits-privatemessages-read-report-save-submit-subscribe-vote'

View File

@@ -159,7 +159,7 @@ class BaseContent(object):
data = {} data = {}
data['object'] = subscription data['object'] = subscription
data['type'] = 'Subscription' data['type'] = 'Subscription'
data['name'] = "/r/" + subscription._case_name data['name'] = "/r/" + subscription.display_name
data['title'] = subscription.title data['title'] = subscription.title
return data return data
@@ -385,14 +385,17 @@ class SubredditContent(BaseContent):
return data return data
class SubscriptionContent(BaseContent): class SubscriptionContent(BaseContent):
def __init__(self, subscriptions, loader): def __init__(self, subscriptions, loader):
self.name = "Subscriptions" self.name = "Subscriptions"
self.order = None
self._loader = loader self._loader = loader
self._subscriptions = subscriptions self._subscriptions = subscriptions
self._subscription_data = [] self._subscription_data = []
@classmethod @classmethod
def get_list(cls, reddit, loader): def from_user(cls, reddit, loader):
try: try:
with loader(): with loader():
subscriptions = reddit.get_my_subreddits(limit=None) subscriptions = reddit.get_my_subreddits(limit=None)
@@ -421,7 +424,7 @@ class SubscriptionContent(BaseContent):
self._subscription_data.append(data) self._subscription_data.append(data)
data = self._subscription_data[index] data = self._subscription_data[index]
data['split_title'] = wrap_text(data['name'], width=n_cols) data['split_title'] = wrap_text(data['title'], width=n_cols)
data['n_rows'] = len(data['split_title']) + 1 data['n_rows'] = len(data['split_title']) + 1
data['offset'] = 0 data['offset'] = 0

View File

@@ -1,6 +1,6 @@
from .__version__ import __version__ from .__version__ import __version__
__all__ = ['AGENT', 'SUMMARY', 'AUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE', __all__ = ['AGENT', 'SUMMARY', 'AUTH', 'OAUTH', 'CONTROLS', 'HELP', 'COMMENT_FILE',
'SUBMISSION_FILE', 'COMMENT_EDIT_FILE'] 'SUBMISSION_FILE', 'COMMENT_EDIT_FILE']
AGENT = """\ AGENT = """\
@@ -17,6 +17,11 @@ Authenticating is required to vote and leave comments. If only a username is
given, the program will display a secure prompt to enter a password. given, the program will display a secure prompt to enter a password.
""" """
OAUTH = """\
Authentication is now done by OAuth, since PRAW will stop supporting login with
username and password soon.
"""
CONTROLS = """ CONTROLS = """
Controls Controls
-------- --------
@@ -42,7 +47,7 @@ Authenticated Commands
`c` : Compose a new post or comment `c` : Compose a new post or comment
`e` : Edit an existing post or comment `e` : Edit an existing post or comment
`d` : Delete an existing post or comment `d` : Delete an existing post or comment
`s` : Open subscribed subreddits list `s` : Open/close subscribed subreddits list
Subreddit Mode Subreddit Mode
`l` or `RIGHT` : Enter the selected submission `l` or `RIGHT` : Enter the selected submission

189
rtv/oauth.py Normal file
View File

@@ -0,0 +1,189 @@
from six.moves import configparser
import curses
import logging
import os
import time
import uuid
import webbrowser
import praw
from . import config
from .curses_helpers import show_notification, prompt_input
from tornado import ioloop, web
__all__ = ['token_validity', 'OAuthTool']
_logger = logging.getLogger(__name__)
token_validity = 3540
oauth_state = None
oauth_code = None
oauth_error = None
class HomeHandler(web.RequestHandler):
def get(self):
self.render('home.html')
class AuthHandler(web.RequestHandler):
def get(self):
global oauth_state
global oauth_code
global oauth_error
oauth_state = self.get_argument('state', default='state_placeholder')
oauth_code = self.get_argument('code', default='code_placeholder')
oauth_error = self.get_argument('error', default='error_placeholder')
self.render('auth.html', state=oauth_state, code=oauth_code, error=oauth_error)
ioloop.IOLoop.current().stop()
class OAuthTool(object):
def __init__(self, reddit, stdscr=None, loader=None,
client_id=None, redirect_uri=None, scope=None):
self.reddit = reddit
self.stdscr = stdscr
self.loader = loader
self.config = configparser.ConfigParser()
self.config_fp = None
self.client_id = client_id or config.oauth_client_id
# Comply with PRAW's desperate need for client secret
self.client_secret = config.oauth_client_secret
self.redirect_uri = redirect_uri or config.oauth_redirect_uri
self.scope = scope or config.oauth_scope.split('-')
self.access_info = {}
self.token_expiration = 0
# Initialize Tornado webapp and listen on port 65000
self.callback_app = web.Application([
(r'/', HomeHandler),
(r'/auth', AuthHandler),
], template_path='rtv/templates')
self.callback_app.listen(65000)
def get_config_fp(self):
HOME = os.path.expanduser('~')
XDG_CONFIG_HOME = os.getenv('XDG_CONFIG_HOME',
os.path.join(HOME, '.config'))
config_paths = [
os.path.join(XDG_CONFIG_HOME, 'rtv', 'rtv.cfg'),
os.path.join(HOME, '.rtv')
]
# get the first existing config file
for config_path in config_paths:
if os.path.exists(config_path):
break
return config_path
def open_config(self, update=False):
if self.config_fp is None:
self.config_fp = self.get_config_fp()
if update:
self.config.read(self.config_fp)
def save_config(self):
self.open_config()
with open(self.config_fp, 'w') as cfg:
self.config.write(cfg)
def set_token_expiration(self):
self.token_expiration = time.time() + token_validity
def token_expired(self):
return time.time() > self.token_expiration
def refresh(self, force=False):
if self.token_expired() or force:
try:
with self.loader(message='Refreshing token'):
new_access_info = self.reddit.refresh_access_information(
self.config['oauth']['refresh_token'])
self.access_info = new_access_info
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
praw.errors.HTTPException) as e:
show_notification(self.stdscr, ['Invalid OAuth data'])
else:
self.config['oauth']['access_token'] = self.access_info['access_token']
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
self.save_config()
def authorize(self):
self.reddit.set_oauth_app_info(self.client_id,
self.client_secret,
self.redirect_uri)
self.open_config(update=True)
# If no previous OAuth data found, starting from scratch
if 'oauth' not in self.config or 'access_token' not in self.config['oauth']:
# Generate a random UUID
hex_uuid = uuid.uuid4().hex
permission_ask_page_link = self.reddit.get_authorize_url(str(hex_uuid),
scope=self.scope, refreshable=True)
with self.loader(message='Waiting for authorization'):
webbrowser.open(permission_ask_page_link)
ioloop.IOLoop.current().start()
global oauth_state
global oauth_code
global oauth_error
self.final_state = oauth_state
self.final_code = oauth_code
self.final_error = oauth_error
# Check if access was denied
if self.final_error == 'access_denied':
show_notification(self.stdscr, ['Declined access'])
return
elif self.final_error != 'error_placeholder':
show_notification(self.stdscr, ['Authentication error'])
return
# Check if UUID matches obtained state
# (if not, authorization process is compromised, and I'm giving up)
if hex_uuid != self.final_state:
show_notification(self.stdscr, ['UUID mismatch, stopping.'])
return
try:
with self.loader(message='Logging in'):
# Get access information (tokens and scopes)
self.access_info = self.reddit.get_access_information(self.final_code)
self.reddit.set_access_credentials(
scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
show_notification(self.stdscr, ['Invalid OAuth data'])
else:
if 'oauth' not in self.config:
self.config['oauth'] = {}
self.config['oauth']['access_token'] = self.access_info['access_token']
self.config['oauth']['refresh_token'] = self.access_info['refresh_token']
self.save_config()
# Otherwise, fetch new access token
else:
self.refresh(force=True)

View File

@@ -12,6 +12,7 @@ from .helpers import open_editor
from .curses_helpers import (Color, show_notification, show_help, prompt_input, from .curses_helpers import (Color, show_notification, show_help, prompt_input,
add_line) add_line)
from .docs import COMMENT_EDIT_FILE, SUBMISSION_FILE from .docs import COMMENT_EDIT_FILE, SUBMISSION_FILE
from .oauth import OAuthTool
__all__ = ['Navigator', 'BaseController', 'BasePage'] __all__ = ['Navigator', 'BaseController', 'BasePage']
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@@ -244,11 +245,12 @@ class BasePage(object):
MIN_HEIGHT = 10 MIN_HEIGHT = 10
MIN_WIDTH = 20 MIN_WIDTH = 20
def __init__(self, stdscr, reddit, content, **kwargs): def __init__(self, stdscr, reddit, content, oauth, **kwargs):
self.stdscr = stdscr self.stdscr = stdscr
self.reddit = reddit self.reddit = reddit
self.content = content self.content = content
self.oauth = oauth
self.nav = Navigator(self.content.get, **kwargs) self.nav = Navigator(self.content.get, **kwargs)
self._header_window = None self._header_window = None
@@ -312,6 +314,9 @@ class BasePage(object):
@BaseController.register('a') @BaseController.register('a')
def upvote(self): def upvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
try: try:
if 'likes' not in data: if 'likes' not in data:
@@ -327,6 +332,9 @@ class BasePage(object):
@BaseController.register('z') @BaseController.register('z')
def downvote(self): def downvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
try: try:
if 'likes' not in data: if 'likes' not in data:
@@ -348,23 +356,11 @@ class BasePage(object):
account. account.
""" """
if self.reddit.is_logged_in(): if self.reddit.is_oauth_session():
self.logout() self.reddit.clear_authentication()
return return
username = prompt_input(self.stdscr, 'Enter username:') self.oauth.authorize()
password = prompt_input(self.stdscr, 'Enter password:', hide=True)
if not username or not password:
curses.flash()
return
try:
with self.loader(message='Logging in'):
self.reddit.login(username, password)
except praw.errors.InvalidUserPass:
show_notification(self.stdscr, ['Invalid user/pass'])
else:
show_notification(self.stdscr, ['Welcome {}'.format(username)])
@BaseController.register('d') @BaseController.register('d')
def delete(self): def delete(self):
@@ -372,10 +368,13 @@ class BasePage(object):
Delete a submission or comment. Delete a submission or comment.
""" """
if not self.reddit.is_logged_in(): if not self.reddit.is_oauth_session():
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name: if data.get('author') != self.reddit.user.name:
curses.flash() curses.flash()
@@ -400,10 +399,13 @@ class BasePage(object):
Edit a submission or comment. Edit a submission or comment.
""" """
if not self.reddit.is_logged_in(): if not self.reddit.is_oauth_session():
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name: if data.get('author') != self.reddit.user.name:
curses.flash() curses.flash()
@@ -432,27 +434,15 @@ class BasePage(object):
s.catch = False s.catch = False
self.refresh_content() self.refresh_content()
@BaseController.register('s')
def get_subscriptions(self):
"""
Displays subscribed subreddits
"""
if not self.reddit.is_logged_in():
show_notification(self.stdscr, ['Not logged in'])
return
data = self.content.get(self.nav.absolute_index)
with self.safe_call as s:
subscriptions = SubscriptionPage(self.stdscr, self.reddit)
subscriptions.loop()
self.refresh_content()
@BaseController.register('i') @BaseController.register('i')
def get_inbox(self): def get_inbox(self):
""" """
Checks the inbox for unread messages and displays a notification. Checks the inbox for unread messages and displays a notification.
""" """
# Refresh access token if expired
self.oauth.refresh()
inbox = len(list(self.reddit.get_unread(limit=1))) inbox = len(list(self.reddit.get_unread(limit=1)))
try: try:
if inbox > 0: if inbox > 0:

View File

@@ -20,10 +20,11 @@ class SubmissionController(BaseController):
class SubmissionPage(BasePage): class SubmissionPage(BasePage):
def __init__(self, stdscr, reddit, url=None, submission=None): def __init__(self, stdscr, reddit, oauth, url=None, submission=None):
self.controller = SubmissionController(self) self.controller = SubmissionController(self)
self.loader = LoadScreen(stdscr) self.loader = LoadScreen(stdscr)
self.oauth = oauth
if url: if url:
content = SubmissionContent.from_url(reddit, url, self.loader) content = SubmissionContent.from_url(reddit, url, self.loader)
elif submission: elif submission:
@@ -32,7 +33,7 @@ class SubmissionPage(BasePage):
raise ValueError('Must specify url or submission') raise ValueError('Must specify url or submission')
super(SubmissionPage, self).__init__(stdscr, reddit, super(SubmissionPage, self).__init__(stdscr, reddit,
content, page_index=-1) content, oauth, page_index=-1)
def loop(self): def loop(self):
"Main control loop" "Main control loop"
@@ -88,10 +89,13 @@ class SubmissionPage(BasePage):
selected comment. selected comment.
""" """
if not self.reddit.is_logged_in(): if not self.reddit.is_oauth_session():
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data['type'] == 'Submission': if data['type'] == 'Submission':
content = data['text'] content = data['text']
@@ -127,6 +131,9 @@ class SubmissionPage(BasePage):
def delete_comment(self): def delete_comment(self):
"Delete a comment as long as it is not the current submission" "Delete a comment as long as it is not the current submission"
# Refresh access token if expired
self.oauth.refresh()
if self.nav.absolute_index != -1: if self.nav.absolute_index != -1:
self.delete() self.delete()
else: else:

View File

@@ -33,13 +33,14 @@ class SubredditController(BaseController):
class SubredditPage(BasePage): class SubredditPage(BasePage):
def __init__(self, stdscr, reddit, name): def __init__(self, stdscr, reddit, oauth, name):
self.controller = SubredditController(self) self.controller = SubredditController(self)
self.loader = LoadScreen(stdscr) self.loader = LoadScreen(stdscr)
self.oauth = oauth
content = SubredditContent.from_name(reddit, name, self.loader) content = SubredditContent.from_name(reddit, name, self.loader)
super(SubredditPage, self).__init__(stdscr, reddit, content) super(SubredditPage, self).__init__(stdscr, reddit, content, oauth)
def loop(self): def loop(self):
"Main control loop" "Main control loop"
@@ -53,6 +54,9 @@ class SubredditPage(BasePage):
def refresh_content(self, name=None, order=None): def refresh_content(self, name=None, order=None):
"Re-download all submissions and reset the page index" "Re-download all submissions and reset the page index"
# Refresh access token if expired
self.oauth.refresh()
name = name or self.content.name name = name or self.content.name
order = order or self.content.order order = order or self.content.order
@@ -104,10 +108,9 @@ class SubredditPage(BasePage):
"Select the current submission to view posts" "Select the current submission to view posts"
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
page = SubmissionPage(self.stdscr, self.reddit, url=data['permalink']) page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=data['permalink'])
page.loop() page.loop()
if data['url_type'] == 'selfpost':
if data['url'] == 'selfpost':
global history global history
history.add(data['url_full']) history.add(data['url_full'])
@@ -117,22 +120,25 @@ class SubredditPage(BasePage):
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
url = data['url_full'] url = data['url_full']
global history
history.add(url)
if data['url_type'] in ['x-post', 'selfpost']: if data['url_type'] in ['x-post', 'selfpost']:
page = SubmissionPage(self.stdscr, self.reddit, url=url) page = SubmissionPage(self.stdscr, self.reddit, self.oauth, url=url)
page.loop() page.loop()
else: else:
open_browser(url) open_browser(url)
global history
history.add(url)
@SubredditController.register('c') @SubredditController.register('c')
def post_submission(self): def post_submission(self):
"Post a new submission to the given subreddit" "Post a new submission to the given subreddit"
if not self.reddit.is_logged_in(): if not self.reddit.is_oauth_session():
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
# Strips the subreddit to just the name # Strips the subreddit to just the name
# Make sure it is a valid subreddit for submission # Make sure it is a valid subreddit for submission
subreddit = self.reddit.get_subreddit(self.content.name) subreddit = self.reddit.get_subreddit(self.content.name)
@@ -162,7 +168,7 @@ class SubredditPage(BasePage):
time.sleep(2.0) time.sleep(2.0)
# Open the newly created post # Open the newly created post
s.catch = False s.catch = False
page = SubmissionPage(self.stdscr, self.reddit, submission=post) page = SubmissionPage(self.stdscr, self.reddit, self.oauth, submission=post)
page.loop() page.loop()
self.refresh_content() self.refresh_content()
@@ -170,13 +176,22 @@ class SubredditPage(BasePage):
def open_subscriptions(self): def open_subscriptions(self):
"Open user subscriptions page" "Open user subscriptions page"
if not self.reddit.is_logged_in(): if not self.reddit.is_oauth_session():
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
page = SubscriptionPage(self.stdscr, self.reddit) # Refresh access token if expired
self.oauth.refresh()
# Open subscriptions page
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
page.loop() page.loop()
# When user has chosen a subreddit in the subscriptions list,
# refresh content with the selected subreddit
if page.selected_subreddit_data is not None:
self.refresh_content(name=page.selected_subreddit_data['name'])
@staticmethod @staticmethod
def draw_item(win, data, inverted=False): def draw_item(win, data, inverted=False):

View File

@@ -14,12 +14,16 @@ class SubscriptionController(BaseController):
character_map = {} character_map = {}
class SubscriptionPage(BasePage): class SubscriptionPage(BasePage):
def __init__(self, stdscr, reddit):
def __init__(self, stdscr, reddit, oauth):
self.controller = SubscriptionController(self) self.controller = SubscriptionController(self)
self.loader = LoadScreen(stdscr) self.loader = LoadScreen(stdscr)
self.oauth = oauth
self.selected_subreddit_data = None
content = SubscriptionContent.get_list(reddit, self.loader) content = SubscriptionContent.from_user(reddit, self.loader)
super(SubscriptionPage, self).__init__(stdscr, reddit, content) super(SubscriptionPage, self).__init__(stdscr, reddit, content, oauth)
def loop(self): def loop(self):
"Main control loop" "Main control loop"
@@ -34,19 +38,20 @@ class SubscriptionPage(BasePage):
def refresh_content(self): def refresh_content(self):
"Re-download all subscriptions and reset the page index" "Re-download all subscriptions and reset the page index"
self.content = SubscriptionContent.get_list(self.reddit, self.loader) # Refresh access token if expired
self.oauth.refresh()
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
self.nav = Navigator(self.content.get) self.nav = Navigator(self.content.get)
@SubscriptionController.register(curses.KEY_ENTER, 10) @SubscriptionController.register(curses.KEY_ENTER, 10, curses.KEY_RIGHT)
def open_selected_subreddit(self): def store_selected_subreddit(self):
"Open the selected subreddit" "Store the selected subreddit and return to the subreddit page"
from .subreddit import SubredditPage self.selected_subreddit_data = self.content.get(self.nav.absolute_index)
data = self.content.get(self.nav.absolute_index) self.active = False
page = SubredditPage(self.stdscr, self.reddit, data['name'][2:]) # Strip the leading /r
page.loop()
@SubscriptionController.register(curses.KEY_LEFT) @SubscriptionController.register(curses.KEY_LEFT, 'h', 's')
def close_subscriptions(self): def close_subscriptions(self):
"Close subscriptions and return to the subreddit page" "Close subscriptions and return to the subreddit page"
@@ -61,12 +66,12 @@ class SubscriptionPage(BasePage):
valid_rows = range(0, n_rows) valid_rows = range(0, n_rows)
offset = 0 if not inverted else -(data['n_rows'] - n_rows) offset = 0 if not inverted else -(data['n_rows'] - n_rows)
n_title = len(data['split_title']) row = offset
for row, text in enumerate(data['split_title'], start=offset):
if row in valid_rows:
attr = curses.A_BOLD | Color.YELLOW
add_line(win, u'{name}'.format(**data), row, 1, attr)
row = n_title + offset
if row in valid_rows: if row in valid_rows:
add_line(win, u'{title}'.format(**data), row, 1) attr = curses.A_BOLD | Color.YELLOW
add_line(win, u'{name}'.format(**data), row, 1, attr)
row = offset + 1
for row, text in enumerate(data['split_title'], start=row):
if row in valid_rows:
add_line(win, text, row, 1)

15
rtv/templates/auth.html Normal file
View File

@@ -0,0 +1,15 @@
<!DOCTYPE html>
<title>RTV OAuth</title>
{% if error == 'access_denied' %}
<h3 style="color: red">Declined rtv access</h3>
<p>You chose to stop <span style="font-weight: bold">Reddit Terminal Viewer</span> from accessing your account, it will continue in unauthenticated mode.<br>
You can close this page.</p>
{% elif error != 'error_placeholder' %}
<h3 style="color: red">Error : {{ error }}</h3>
{% elif (state == 'state_placeholder' or code == 'code_placeholder') and error == 'error_placeholder' %}
<h3>Wait...</h3>
<p>This page is supposed to be a Reddit OAuth callback. You can't just come here hands in the pocket!</p>
{% else %}
<h3 style="color: green">Allowed rtv access</h3>
<p><span style="font-weight: bold">Reddit Terminal Viewer</span> will now log in. You can close this page.</p>
{% end %}

3
rtv/templates/home.html Normal file
View File

@@ -0,0 +1,3 @@
<!DOCTYPE html>
<title>OAuth helper</title>
<h1>Reddit Terminal Viewer OAuth helper</h1>

View File

@@ -13,7 +13,7 @@ setup(
keywords='reddit terminal praw curses', keywords='reddit terminal praw curses',
packages=['rtv'], packages=['rtv'],
include_package_data=True, include_package_data=True,
install_requires=['praw>=3.1.0', 'six', 'requests', 'kitchen'], install_requires=['tornado', 'praw>=3.1.0', 'six', 'requests', 'kitchen'],
entry_points={'console_scripts': ['rtv=rtv.__main__:main']}, entry_points={'console_scripts': ['rtv=rtv.__main__:main']},
classifiers=[ classifiers=[
'Intended Audience :: End Users/Desktop', 'Intended Audience :: End Users/Desktop',