Refactoring and making rtv OAuth-compliant
This commit is contained in:
42
rtv/page.py
42
rtv/page.py
@@ -12,6 +12,7 @@ from .helpers import open_editor
|
||||
from .curses_helpers import (Color, show_notification, show_help, prompt_input,
|
||||
add_line)
|
||||
from .docs import COMMENT_EDIT_FILE, SUBMISSION_FILE
|
||||
from .oauth import OAuthTool
|
||||
|
||||
__all__ = ['Navigator', 'BaseController', 'BasePage']
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -244,11 +245,12 @@ class BasePage(object):
|
||||
MIN_HEIGHT = 10
|
||||
MIN_WIDTH = 20
|
||||
|
||||
def __init__(self, stdscr, reddit, content, **kwargs):
|
||||
def __init__(self, stdscr, reddit, content, oauth, **kwargs):
|
||||
|
||||
self.stdscr = stdscr
|
||||
self.reddit = reddit
|
||||
self.content = content
|
||||
self.oauth = oauth
|
||||
self.nav = Navigator(self.content.get, **kwargs)
|
||||
|
||||
self._header_window = None
|
||||
@@ -312,6 +314,9 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('a')
|
||||
def upvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -327,6 +332,9 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('z')
|
||||
def downvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -348,23 +356,11 @@ class BasePage(object):
|
||||
account.
|
||||
"""
|
||||
|
||||
if self.reddit.is_logged_in():
|
||||
self.logout()
|
||||
if self.reddit.is_oauth_session():
|
||||
self.reddit.clear_authentication()
|
||||
return
|
||||
|
||||
username = prompt_input(self.stdscr, 'Enter username:')
|
||||
password = prompt_input(self.stdscr, 'Enter password:', hide=True)
|
||||
if not username or not password:
|
||||
curses.flash()
|
||||
return
|
||||
|
||||
try:
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.login(username, password)
|
||||
except praw.errors.InvalidUserPass:
|
||||
show_notification(self.stdscr, ['Invalid user/pass'])
|
||||
else:
|
||||
show_notification(self.stdscr, ['Welcome {}'.format(username)])
|
||||
self.oauth.authorize()
|
||||
|
||||
@BaseController.register('d')
|
||||
def delete(self):
|
||||
@@ -372,10 +368,13 @@ class BasePage(object):
|
||||
Delete a submission or comment.
|
||||
"""
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -400,10 +399,13 @@ class BasePage(object):
|
||||
Edit a submission or comment.
|
||||
"""
|
||||
|
||||
if not self.reddit.is_logged_in():
|
||||
if not self.reddit.is_oauth_session():
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -437,6 +439,10 @@ class BasePage(object):
|
||||
"""
|
||||
Checks the inbox for unread messages and displays a notification.
|
||||
"""
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
inbox = len(list(self.reddit.get_unread(limit=1)))
|
||||
try:
|
||||
if inbox > 0:
|
||||
|
||||
Reference in New Issue
Block a user