Let PRAW manage authentication

This commit is contained in:
Théo Piboubès
2015-09-01 22:32:56 +02:00
parent 314d2dbf26
commit f6546aaf75
6 changed files with 4 additions and 73 deletions

View File

@@ -103,7 +103,6 @@ def command_line():
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH) oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting') oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token') oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
args = parser.parse_args() args = parser.parse_args()

View File

@@ -16,8 +16,6 @@ from tornado import ioloop, web
__all__ = ['token_validity', 'OAuthTool'] __all__ = ['token_validity', 'OAuthTool']
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
token_validity = 3540
oauth_state = None oauth_state = None
oauth_code = None oauth_code = None
oauth_error = None oauth_error = None
@@ -62,8 +60,6 @@ class OAuthTool(object):
self.access_info = {} self.access_info = {}
self.token_expiration = 0
# Initialize Tornado webapp and listen on port 65000 # Initialize Tornado webapp and listen on port 65000
self.callback_app = web.Application([ self.callback_app = web.Application([
(r'/', HomeHandler), (r'/', HomeHandler),
@@ -95,31 +91,6 @@ class OAuthTool(object):
with open(self.config_fp, 'w') as cfg: with open(self.config_fp, 'w') as cfg:
self.config.write(cfg) self.config.write(cfg)
def set_token_expiration(self):
self.token_expiration = time.time() + token_validity
def token_expired(self):
return time.time() > self.token_expiration
def refresh(self, force=False):
if self.token_expired() or force:
try:
with self.loader(message='Refreshing token'):
new_access_info = self.reddit.refresh_access_information(
self.config.get('oauth', 'refresh_token'))
self.access_info = new_access_info
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
praw.errors.HTTPException) as e:
show_notification(self.stdscr, ['Invalid OAuth data'])
else:
self.config.set('oauth', 'access_token', self.access_info['access_token'])
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
self.save_config()
def authorize(self): def authorize(self):
self.reddit.set_oauth_app_info(self.client_id, self.reddit.set_oauth_app_info(self.client_id,
self.client_secret, self.client_secret,
@@ -127,7 +98,7 @@ class OAuthTool(object):
self.open_config(update=True) self.open_config(update=True)
# If no previous OAuth data found, starting from scratch # If no previous OAuth data found, starting from scratch
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'access_token'): if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'refresh_token'):
# Generate a random UUID # Generate a random UUID
hex_uuid = uuid.uuid4().hex hex_uuid = uuid.uuid4().hex
@@ -164,21 +135,15 @@ class OAuthTool(object):
with self.loader(message='Logging in'): with self.loader(message='Logging in'):
# Get access information (tokens and scopes) # Get access information (tokens and scopes)
self.access_info = self.reddit.get_access_information(self.final_code) self.access_info = self.reddit.get_access_information(self.final_code)
self.reddit.set_access_credentials(
scope=set(self.access_info['scope']),
access_token=self.access_info['access_token'],
refresh_token=self.access_info['refresh_token'])
self.set_token_expiration()
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e: except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
show_notification(self.stdscr, ['Invalid OAuth data']) show_notification(self.stdscr, ['Invalid OAuth data'])
else: else:
if not self.config.has_section('oauth'): if not self.config.has_section('oauth'):
self.config.add_section('oauth') self.config.add_section('oauth')
self.config.set('oauth', 'access_token', self.access_info['access_token'])
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token']) self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
self.save_config() self.save_config()
# Otherwise, fetch new access token # Otherwise, fetch new access token
else: else:
self.refresh(force=True) with self.loader(message='Logging in'):
self.reddit.refresh_access_information(self.config.get('oauth', 'refresh_token'))

View File

@@ -314,9 +314,6 @@ class BasePage(object):
@BaseController.register('a') @BaseController.register('a')
def upvote(self): def upvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
try: try:
if 'likes' not in data: if 'likes' not in data:
@@ -332,9 +329,6 @@ class BasePage(object):
@BaseController.register('z') @BaseController.register('z')
def downvote(self): def downvote(self):
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
try: try:
if 'likes' not in data: if 'likes' not in data:
@@ -372,9 +366,6 @@ class BasePage(object):
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name: if data.get('author') != self.reddit.user.name:
curses.flash() curses.flash()
@@ -403,9 +394,6 @@ class BasePage(object):
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data.get('author') != self.reddit.user.name: if data.get('author') != self.reddit.user.name:
curses.flash() curses.flash()
@@ -440,9 +428,6 @@ class BasePage(object):
Checks the inbox for unread messages and displays a notification. Checks the inbox for unread messages and displays a notification.
""" """
# Refresh access token if expired
self.oauth.refresh()
inbox = len(list(self.reddit.get_unread(limit=1))) inbox = len(list(self.reddit.get_unread(limit=1)))
try: try:
if inbox > 0: if inbox > 0:

View File

@@ -93,9 +93,6 @@ class SubmissionPage(BasePage):
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
data = self.content.get(self.nav.absolute_index) data = self.content.get(self.nav.absolute_index)
if data['type'] == 'Submission': if data['type'] == 'Submission':
content = data['text'] content = data['text']
@@ -131,9 +128,6 @@ class SubmissionPage(BasePage):
def delete_comment(self): def delete_comment(self):
"Delete a comment as long as it is not the current submission" "Delete a comment as long as it is not the current submission"
# Refresh access token if expired
self.oauth.refresh()
if self.nav.absolute_index != -1: if self.nav.absolute_index != -1:
self.delete() self.delete()
else: else:

View File

@@ -8,7 +8,7 @@ import requests
from .exceptions import SubredditError, AccountError from .exceptions import SubredditError, AccountError
from .page import BasePage, Navigator, BaseController from .page import BasePage, Navigator, BaseController
from .submission import SubmissionPage from .submission import SubmissionPage
from .subscriptions import SubscriptionPage from .subscription import SubscriptionPage
from .content import SubredditContent from .content import SubredditContent
from .helpers import open_browser, open_editor, strip_subreddit_url from .helpers import open_browser, open_editor, strip_subreddit_url
from .docs import SUBMISSION_FILE from .docs import SUBMISSION_FILE
@@ -54,9 +54,6 @@ class SubredditPage(BasePage):
def refresh_content(self, name=None, order=None): def refresh_content(self, name=None, order=None):
"Re-download all submissions and reset the page index" "Re-download all submissions and reset the page index"
# Refresh access token if expired
self.oauth.refresh()
name = name or self.content.name name = name or self.content.name
order = order or self.content.order order = order or self.content.order
@@ -136,9 +133,6 @@ class SubredditPage(BasePage):
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
# Strips the subreddit to just the name # Strips the subreddit to just the name
# Make sure it is a valid subreddit for submission # Make sure it is a valid subreddit for submission
subreddit = self.reddit.get_subreddit(self.content.name) subreddit = self.reddit.get_subreddit(self.content.name)
@@ -180,9 +174,6 @@ class SubredditPage(BasePage):
show_notification(self.stdscr, ['Not logged in']) show_notification(self.stdscr, ['Not logged in'])
return return
# Refresh access token if expired
self.oauth.refresh()
# Open subscriptions page # Open subscriptions page
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth) page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
page.loop() page.loop()

View File

@@ -38,9 +38,6 @@ class SubscriptionPage(BasePage):
def refresh_content(self): def refresh_content(self):
"Re-download all subscriptions and reset the page index" "Re-download all subscriptions and reset the page index"
# Refresh access token if expired
self.oauth.refresh()
self.content = SubscriptionContent.from_user(self.reddit, self.loader) self.content = SubscriptionContent.from_user(self.reddit, self.loader)
self.nav = Navigator(self.content.get) self.nav = Navigator(self.content.get)