Let PRAW manage authentication
This commit is contained in:
@@ -103,7 +103,6 @@ def command_line():
|
|||||||
|
|
||||||
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
|
oauth_group = parser.add_argument_group('OAuth data (optional)', OAUTH)
|
||||||
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
|
oauth_group.add_argument('--auto-login', dest='auto_login', help='OAuth auto-login setting')
|
||||||
oauth_group.add_argument('--auth-token', dest='access_token', help='OAuth authorization token')
|
|
||||||
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
|
oauth_group.add_argument('--refresh-token', dest='refresh_token', help='OAuth refresh token')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
41
rtv/oauth.py
41
rtv/oauth.py
@@ -16,8 +16,6 @@ from tornado import ioloop, web
|
|||||||
__all__ = ['token_validity', 'OAuthTool']
|
__all__ = ['token_validity', 'OAuthTool']
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
token_validity = 3540
|
|
||||||
|
|
||||||
oauth_state = None
|
oauth_state = None
|
||||||
oauth_code = None
|
oauth_code = None
|
||||||
oauth_error = None
|
oauth_error = None
|
||||||
@@ -62,8 +60,6 @@ class OAuthTool(object):
|
|||||||
|
|
||||||
self.access_info = {}
|
self.access_info = {}
|
||||||
|
|
||||||
self.token_expiration = 0
|
|
||||||
|
|
||||||
# Initialize Tornado webapp and listen on port 65000
|
# Initialize Tornado webapp and listen on port 65000
|
||||||
self.callback_app = web.Application([
|
self.callback_app = web.Application([
|
||||||
(r'/', HomeHandler),
|
(r'/', HomeHandler),
|
||||||
@@ -95,31 +91,6 @@ class OAuthTool(object):
|
|||||||
with open(self.config_fp, 'w') as cfg:
|
with open(self.config_fp, 'w') as cfg:
|
||||||
self.config.write(cfg)
|
self.config.write(cfg)
|
||||||
|
|
||||||
def set_token_expiration(self):
|
|
||||||
self.token_expiration = time.time() + token_validity
|
|
||||||
|
|
||||||
def token_expired(self):
|
|
||||||
return time.time() > self.token_expiration
|
|
||||||
|
|
||||||
def refresh(self, force=False):
|
|
||||||
if self.token_expired() or force:
|
|
||||||
try:
|
|
||||||
with self.loader(message='Refreshing token'):
|
|
||||||
new_access_info = self.reddit.refresh_access_information(
|
|
||||||
self.config.get('oauth', 'refresh_token'))
|
|
||||||
self.access_info = new_access_info
|
|
||||||
self.reddit.set_access_credentials(scope=set(self.access_info['scope']),
|
|
||||||
access_token=self.access_info['access_token'],
|
|
||||||
refresh_token=self.access_info['refresh_token'])
|
|
||||||
self.set_token_expiration()
|
|
||||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken,
|
|
||||||
praw.errors.HTTPException) as e:
|
|
||||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
|
||||||
else:
|
|
||||||
self.config.set('oauth', 'access_token', self.access_info['access_token'])
|
|
||||||
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
|
||||||
self.save_config()
|
|
||||||
|
|
||||||
def authorize(self):
|
def authorize(self):
|
||||||
self.reddit.set_oauth_app_info(self.client_id,
|
self.reddit.set_oauth_app_info(self.client_id,
|
||||||
self.client_secret,
|
self.client_secret,
|
||||||
@@ -127,7 +98,7 @@ class OAuthTool(object):
|
|||||||
|
|
||||||
self.open_config(update=True)
|
self.open_config(update=True)
|
||||||
# If no previous OAuth data found, starting from scratch
|
# If no previous OAuth data found, starting from scratch
|
||||||
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'access_token'):
|
if not self.config.has_section('oauth') or not self.config.has_option('oauth', 'refresh_token'):
|
||||||
# Generate a random UUID
|
# Generate a random UUID
|
||||||
hex_uuid = uuid.uuid4().hex
|
hex_uuid = uuid.uuid4().hex
|
||||||
|
|
||||||
@@ -164,21 +135,15 @@ class OAuthTool(object):
|
|||||||
with self.loader(message='Logging in'):
|
with self.loader(message='Logging in'):
|
||||||
# Get access information (tokens and scopes)
|
# Get access information (tokens and scopes)
|
||||||
self.access_info = self.reddit.get_access_information(self.final_code)
|
self.access_info = self.reddit.get_access_information(self.final_code)
|
||||||
|
|
||||||
self.reddit.set_access_credentials(
|
|
||||||
scope=set(self.access_info['scope']),
|
|
||||||
access_token=self.access_info['access_token'],
|
|
||||||
refresh_token=self.access_info['refresh_token'])
|
|
||||||
self.set_token_expiration()
|
|
||||||
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
except (praw.errors.OAuthAppRequired, praw.errors.OAuthInvalidToken) as e:
|
||||||
show_notification(self.stdscr, ['Invalid OAuth data'])
|
show_notification(self.stdscr, ['Invalid OAuth data'])
|
||||||
else:
|
else:
|
||||||
if not self.config.has_section('oauth'):
|
if not self.config.has_section('oauth'):
|
||||||
self.config.add_section('oauth')
|
self.config.add_section('oauth')
|
||||||
|
|
||||||
self.config.set('oauth', 'access_token', self.access_info['access_token'])
|
|
||||||
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
self.config.set('oauth', 'refresh_token', self.access_info['refresh_token'])
|
||||||
self.save_config()
|
self.save_config()
|
||||||
# Otherwise, fetch new access token
|
# Otherwise, fetch new access token
|
||||||
else:
|
else:
|
||||||
self.refresh(force=True)
|
with self.loader(message='Logging in'):
|
||||||
|
self.reddit.refresh_access_information(self.config.get('oauth', 'refresh_token'))
|
||||||
|
|||||||
15
rtv/page.py
15
rtv/page.py
@@ -314,9 +314,6 @@ class BasePage(object):
|
|||||||
|
|
||||||
@BaseController.register('a')
|
@BaseController.register('a')
|
||||||
def upvote(self):
|
def upvote(self):
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
data = self.content.get(self.nav.absolute_index)
|
data = self.content.get(self.nav.absolute_index)
|
||||||
try:
|
try:
|
||||||
if 'likes' not in data:
|
if 'likes' not in data:
|
||||||
@@ -332,9 +329,6 @@ class BasePage(object):
|
|||||||
|
|
||||||
@BaseController.register('z')
|
@BaseController.register('z')
|
||||||
def downvote(self):
|
def downvote(self):
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
data = self.content.get(self.nav.absolute_index)
|
data = self.content.get(self.nav.absolute_index)
|
||||||
try:
|
try:
|
||||||
if 'likes' not in data:
|
if 'likes' not in data:
|
||||||
@@ -372,9 +366,6 @@ class BasePage(object):
|
|||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
data = self.content.get(self.nav.absolute_index)
|
data = self.content.get(self.nav.absolute_index)
|
||||||
if data.get('author') != self.reddit.user.name:
|
if data.get('author') != self.reddit.user.name:
|
||||||
curses.flash()
|
curses.flash()
|
||||||
@@ -403,9 +394,6 @@ class BasePage(object):
|
|||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
data = self.content.get(self.nav.absolute_index)
|
data = self.content.get(self.nav.absolute_index)
|
||||||
if data.get('author') != self.reddit.user.name:
|
if data.get('author') != self.reddit.user.name:
|
||||||
curses.flash()
|
curses.flash()
|
||||||
@@ -440,9 +428,6 @@ class BasePage(object):
|
|||||||
Checks the inbox for unread messages and displays a notification.
|
Checks the inbox for unread messages and displays a notification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
inbox = len(list(self.reddit.get_unread(limit=1)))
|
inbox = len(list(self.reddit.get_unread(limit=1)))
|
||||||
try:
|
try:
|
||||||
if inbox > 0:
|
if inbox > 0:
|
||||||
|
|||||||
@@ -93,9 +93,6 @@ class SubmissionPage(BasePage):
|
|||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
data = self.content.get(self.nav.absolute_index)
|
data = self.content.get(self.nav.absolute_index)
|
||||||
if data['type'] == 'Submission':
|
if data['type'] == 'Submission':
|
||||||
content = data['text']
|
content = data['text']
|
||||||
@@ -131,9 +128,6 @@ class SubmissionPage(BasePage):
|
|||||||
def delete_comment(self):
|
def delete_comment(self):
|
||||||
"Delete a comment as long as it is not the current submission"
|
"Delete a comment as long as it is not the current submission"
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
if self.nav.absolute_index != -1:
|
if self.nav.absolute_index != -1:
|
||||||
self.delete()
|
self.delete()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import requests
|
|||||||
from .exceptions import SubredditError, AccountError
|
from .exceptions import SubredditError, AccountError
|
||||||
from .page import BasePage, Navigator, BaseController
|
from .page import BasePage, Navigator, BaseController
|
||||||
from .submission import SubmissionPage
|
from .submission import SubmissionPage
|
||||||
from .subscriptions import SubscriptionPage
|
from .subscription import SubscriptionPage
|
||||||
from .content import SubredditContent
|
from .content import SubredditContent
|
||||||
from .helpers import open_browser, open_editor, strip_subreddit_url
|
from .helpers import open_browser, open_editor, strip_subreddit_url
|
||||||
from .docs import SUBMISSION_FILE
|
from .docs import SUBMISSION_FILE
|
||||||
@@ -54,9 +54,6 @@ class SubredditPage(BasePage):
|
|||||||
def refresh_content(self, name=None, order=None):
|
def refresh_content(self, name=None, order=None):
|
||||||
"Re-download all submissions and reset the page index"
|
"Re-download all submissions and reset the page index"
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
name = name or self.content.name
|
name = name or self.content.name
|
||||||
order = order or self.content.order
|
order = order or self.content.order
|
||||||
|
|
||||||
@@ -136,9 +133,6 @@ class SubredditPage(BasePage):
|
|||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
# Strips the subreddit to just the name
|
# Strips the subreddit to just the name
|
||||||
# Make sure it is a valid subreddit for submission
|
# Make sure it is a valid subreddit for submission
|
||||||
subreddit = self.reddit.get_subreddit(self.content.name)
|
subreddit = self.reddit.get_subreddit(self.content.name)
|
||||||
@@ -180,9 +174,6 @@ class SubredditPage(BasePage):
|
|||||||
show_notification(self.stdscr, ['Not logged in'])
|
show_notification(self.stdscr, ['Not logged in'])
|
||||||
return
|
return
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
# Open subscriptions page
|
# Open subscriptions page
|
||||||
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
|
page = SubscriptionPage(self.stdscr, self.reddit, self.oauth)
|
||||||
page.loop()
|
page.loop()
|
||||||
|
|||||||
@@ -38,9 +38,6 @@ class SubscriptionPage(BasePage):
|
|||||||
def refresh_content(self):
|
def refresh_content(self):
|
||||||
"Re-download all subscriptions and reset the page index"
|
"Re-download all subscriptions and reset the page index"
|
||||||
|
|
||||||
# Refresh access token if expired
|
|
||||||
self.oauth.refresh()
|
|
||||||
|
|
||||||
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
|
self.content = SubscriptionContent.from_user(self.reddit, self.loader)
|
||||||
self.nav = Navigator(self.content.get)
|
self.nav = Navigator(self.content.get)
|
||||||
|
|
||||||
Reference in New Issue
Block a user