Let PRAW manage authentication
This commit is contained in:
@@ -103,7 +103,6 @@ def command_line():
|
||||
|
||||
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
|
||||
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
|
||||
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
|
||||
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
41
rtv/oauth.py
41
rtv/oauth.py
@@ -16,8 +16,6 @@ from tornado import ioloop, web
|
||||
__all__ = ['token_validity', 'OAuthTool']
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
token_validity = 3540
|
||||
|
||||
oauth_state = None
|
||||
oauth_code = None
|
||||
oauth_error = None
|
||||
@@ -62,8 +60,6 @@ class OAuthTool(object):
|
||||
|
||||
self.access_info = {}
|
||||
|
||||
self.token_expiration = 0
|
||||
|
||||
# Initialize Tornado webapp and listen on port 65000
|
||||
self.callback_app = web.Application([
|
||||
(r'/', HomeHandler),
|
||||
@@ -95,31 +91,6 @@ class OAuthTool(object):
|
||||
with open(self.config_fp, 'w') as cfg:
|
||||
self.config.write(cfg)
|
||||
|
||||
def set_token_expiration(self):
|
||||
self.token_expiration = time.time() + token_validity
|
||||
|
||||
def token_expired(self):
|
||||
return time.time() > self.token_expiration
|
||||
|
||||
def refresh(self, force=False):
|
||||
if self.token_expired() or force:
|
||||
try:
|
||||
with self.loader(message='Refreshing token'):
|
||||
new_access_info = self.reddit.refresh_access_information(
|
||||
self.config.get('oauth', 'refresh_token'))
|
||||
self.access_info = new_access_info
|
||||
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
|
||||
praw.errors.HTTPException) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
self.config.set('oauth', 'access_token', self.access_info['access_token'])
|
||||
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
||||
self.save_config()
|
||||
|
||||
def authorize(self):
|
||||
self.reddit.set_oauth_app_info(self.client_id,
|
||||
self.client_secret,
|
||||
@@ -127,7 +98,7 @@ class OAuthTool(object):
|
||||
|
||||
self.open_config(update=True)
|
||||
# If no previous OAuth data found, starting from scratch
|
||||
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'access_token'):
|
||||
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'refresh_token'):
|
||||
# Generate a random UUID
|
||||
hex_uuid = uuid.uuid4().hex
|
||||
|
||||
@@ -164,21 +135,15 @@ class OAuthTool(object):
|
||||
with self.loader(message='Logging in'):
|
||||
# Get access information (tokens and scopes)
|
||||
self.access_info = self.reddit.get_access_information(self.final_code)
|
||||
|
||||
self.reddit.set_access_credentials(
|
||||
scope=set(self.access_info['scope']),
|
||||
access_token=self.access_info['access_token'],
|
||||
refresh_token=self.access_info['refresh_token'])
|
||||
self.set_token_expiration()
|
||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||
else:
|
||||
if not self.config.has_section('oauth'):
|
||||
self.config.add_section('oauth')
|
||||
|
||||
self.config.set('oauth', 'access_token', self.access_info['access_token'])
|
||||
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
||||
self.save_config()
|
||||
# Otherwise, fetch new access token
|
||||
else:
|
||||
self.refresh(force=True)
|
||||
with self.loader(message='Logging in'):
|
||||
self.reddit.refresh_access_information(self.config.get('oauth', 'refresh_token'))
|
||||
|
||||
15
rtv/page.py
15
rtv/page.py
@@ -314,9 +314,6 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('a')
|
||||
def upvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -332,9 +329,6 @@ class BasePage(object):
|
||||
|
||||
@BaseController.register('z')
|
||||
def downvote(self):
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
try:
|
||||
if 'likes' not in data:
|
||||
@@ -372,9 +366,6 @@ class BasePage(object):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -403,9 +394,6 @@ class BasePage(object):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data.get('author') != self.reddit.user.name:
|
||||
curses.flash()
|
||||
@@ -440,9 +428,6 @@ class BasePage(object):
|
||||
Checks the inbox for unread messages and displays a notification.
|
||||
"""
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
inbox = len(list(self.reddit.get_unread(limit=1)))
|
||||
try:
|
||||
if inbox > 0:
|
||||
|
||||
@@ -93,9 +93,6 @@ class SubmissionPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
data = self.content.get(self.nav.absolute_index)
|
||||
if data['type'] == 'Submission':
|
||||
content = data['text']
|
||||
@@ -131,9 +128,6 @@ class SubmissionPage(BasePage):
|
||||
def delete_comment(self):
|
||||
"Delete a comment as long as it is not the current submission"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
if self.nav.absolute_index != -1:
|
||||
self.delete()
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,7 @@ import requests
|
||||
from .exceptions import SubredditError, AccountError
|
||||
from .page import BasePage, Navigator, BaseController
|
||||
from .submission import SubmissionPage
|
||||
from .subscriptions import SubscriptionPage
|
||||
from .subscription import SubscriptionPage
|
||||
from .content import SubredditContent
|
||||
from .helpers import open_browser, open_editor, strip_subreddit_url
|
||||
from .docs import SUBMISSION_FILE
|
||||
@@ -54,9 +54,6 @@ class SubredditPage(BasePage):
|
||||
def refresh_content(self, name=None, order=None):
|
||||
"Re-download all submissions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
name = name or self.content.name
|
||||
order = order or self.content.order
|
||||
|
||||
@@ -136,9 +133,6 @@ class SubredditPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Strips the subreddit to just the name
|
||||
# Make sure it is a valid subreddit for submission
|
||||
subreddit = self.reddit.get_subreddit(self.content.name)
|
||||
@@ -180,9 +174,6 @@ class SubredditPage(BasePage):
|
||||
show_notification(self.stdscr, ['Not logged in'])
|
||||
return
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
# Open subscriptions page
|
||||
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
|
||||
page.loop()
|
||||
|
||||
@@ -38,9 +38,6 @@ class SubscriptionPage(BasePage):
|
||||
def refresh_content(self):
|
||||
"Re-download all subscriptions and reset the page index"
|
||||
|
||||
# Refresh access token if expired
|
||||
self.oauth.refresh()
|
||||
|
||||
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
|
||||
self.nav = Navigator(self.content.get)
|
||||
|
||||
Reference in New Issue
Block a user