Let PRAW manage authentication

This commit is contained in:
Théo Piboubès
2015-09-01 22:32:56 +02:00
parent 314d2dbf26
commit f6546aaf75
6 changed files with 4 additions and 73 deletions

View File

@@ -103,7 +103,6 @@ def command_line():
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
args = parser.parse_args()

View File

@@ -16,8 +16,6 @@ from tornado import ioloop, web
__all__ = ['token_validity', 'OAuthTool']
_logger = logging.getLogger(__name__)
token_validity = 3540
oauth_state = None
oauth_code = None
oauth_error = None
@@ -62,8 +60,6 @@ class OAuthTool(object):
self.access_info = {}
self.token_expiration = 0
# Initialize Tornado webapp and listen on port 65000
self.callback_app = web.Application([
(r'/', HomeHandler),
@@ -95,31 +91,6 @@ class OAuthTool(object):
with open(self.config_fp, 'w') as cfg:
self.config.write(cfg)
def set_token_expiration(self):
self.token_expiration = time.time() + token_validity
def token_expired(self):
return time.time() > self.token_expiration
def refresh(self, force=False):
if self.token_expired() or force:
try:
with self.loader(message='Refreshing token'):
new_access_info = self.reddit.refresh_access_information(
self.config.get('oauth', 'refresh_token'))
self.access_info = new_access_info
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
praw.errors.HTTPException) as e:
show_notification(self.stdscr, ['Invalid OAuth data'])
else:
self.config.set('oauth', 'access_token', self.access_info['access_token'])
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
self.save_config()
def authorize(self):
self.reddit.set_oauth_app_info(self.client_id,
self.client_secret,
@@ -127,7 +98,7 @@ class OAuthTool(object):
self.open_config(update=True)
# If no previous OAuth data found, starting from scratch
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'access_token'):
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'refresh_token'):
# Generate a random UUID
hex_uuid = uuid.uuid4().hex
@@ -164,21 +135,15 @@ class OAuthTool(object):
with self.loader(message='Logging in'):
# Get access information (tokens and scopes)
self.access_info = self.reddit.get_access_information(self.final_code)
self.reddit.set_access_credentials(
scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
show_notification(self.stdscr, ['Invalid OAuth data'])
else:
if not self.config.has_section('oauth'):
self.config.add_section('oauth')
self.config.set('oauth', 'access_token', self.access_info['access_token'])
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
self.save_config()
# Otherwise, fetch new access token
else:
self.refresh(force=True)
with self.loader(message='Logging in'):
self.reddit.refresh_access_information(self.config.get('oauth', 'refresh_token'))

View File

@@ -314,9 +314,6 @@ class BasePage(object):
@BaseController.register('a')
def upvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index)
try:
if 'likes' not in data:
@@ -332,9 +329,6 @@ class BasePage(object):
@BaseController.register('z')
def downvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index)
try:
if 'likes' not in data:
@@ -372,9 +366,6 @@ class BasePage(object):
show_notification(self.stdscr, ['Not logged in'])
return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name:
curses.flash()
@@ -403,9 +394,6 @@ class BasePage(object):
show_notification(self.stdscr, ['Not logged in'])
return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name:
curses.flash()
@@ -440,9 +428,6 @@ class BasePage(object):
Checks the inbox for unread messages and displays a notification.
"""
# Refresh access token if expired
self.oauth.refresh()
inbox = len(list(self.reddit.get_unread(limit=1)))
try:
if inbox > 0:

View File

@@ -93,9 +93,6 @@ class SubmissionPage(BasePage):
show_notification(self.stdscr, ['Not logged in'])
return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index)
if data['type'] == 'Submission':
content = data['text']
@@ -131,9 +128,6 @@ class SubmissionPage(BasePage):
def delete_comment(self):
"Delete a comment as long as it is not the current submission"
# Refresh access token if expired
self.oauth.refresh()
if self.nav.absolute_index != -1:
self.delete()
else:

View File

@@ -8,7 +8,7 @@ import requests
from .exceptions import SubredditError, AccountError
from .page import BasePage, Navigator, BaseController
from .submission import SubmissionPage
from .subscriptions import SubscriptionPage
from .subscription import SubscriptionPage
from .content import SubredditContent
from .helpers import open_browser, open_editor, strip_subreddit_url
from .docs import SUBMISSION_FILE
@@ -54,9 +54,6 @@ class SubredditPage(BasePage):
def refresh_content(self, name=None, order=None):
"Re-download all submissions and reset the page index"
# Refresh access token if expired
self.oauth.refresh()
name = name or self.content.name
order = order or self.content.order
@@ -136,9 +133,6 @@ class SubredditPage(BasePage):
show_notification(self.stdscr, ['Not logged in'])
return
# Refresh access token if expired
self.oauth.refresh()
# Strips the subreddit to just the name
# Make sure it is a valid subreddit for submission
subreddit = self.reddit.get_subreddit(self.content.name)
@@ -180,9 +174,6 @@ class SubredditPage(BasePage):
show_notification(self.stdscr, ['Not logged in'])
return
# Refresh access token if expired
self.oauth.refresh()
# Open subscriptions page
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
page.loop()

View File

@@ -38,9 +38,6 @@ class SubscriptionPage(BasePage):
def refresh_content(self):
"Re-download all subscriptions and reset the page index"
# Refresh access token if expired
self.oauth.refresh()
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
self.nav = Navigator(self.content.get)