326 lines
12 KiB
Python
326 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import unicode_literals
|
|
|
|
import os
|
|
import time
|
|
import uuid
|
|
import string
|
|
import codecs
|
|
import logging
|
|
import threading
|
|
|
|
#pylint: disable=import-error
|
|
from six.moves.urllib.parse import urlparse, parse_qs
|
|
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
|
|
|
|
from . import docs
|
|
from .config import TEMPLATES
|
|
from .exceptions import InvalidRefreshToken
|
|
from .packages.praw.errors import HTTPException, OAuthException
|
|
from .packages.praw.handlers import DefaultHandler
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
INDEX = os.path.join(TEMPLATES, 'index.html')
|
|
|
|
|
|
class OAuthHandler(BaseHTTPRequestHandler):
    """
    HTTP request handler that receives the OAuth redirect callback.

    The query parameters of the callback request are stored in the
    class-level ``params`` dict so that they can be read from outside of
    the handler once the request has been served.
    """

    # params are stored as a global because we don't have control over what
    # gets passed into the handler __init__. These will be accessed by the
    # OAuthHelper class.
    params = {'state': None, 'code': None, 'error': None}
    shutdown_on_request = True

    def do_GET(self):
        """
        Accepts GET requests to the server root (e.g. http://localhost:6500/)
        and stores the query params in the global dict. If
        shutdown_on_request is true, stop the server after the first
        successful request.

        Requests to any other path are answered with a 404 and are otherwise
        ignored: they do not touch the stored params and they do not stop
        the server.

        The http request may contain the following query params:
            - state : unique identifier, should match what we passed to reddit
            - code  : code that can be exchanged for a refresh token
            - error : if provided, the OAuth error that occurred
        """

        parsed_path = urlparse(self.path)
        if parsed_path.path != '/':
            # send_error() already writes a complete response. Falling
            # through would clobber the params, emit a second (200)
            # response on the same connection, and possibly shut the server
            # down because of an unrelated request.
            self.send_error(404)
            return

        qs = parse_qs(parsed_path.query)
        # Mutate the shared dict in place (never rebind it) so that other
        # references to the same dict observe the new values.
        for key in ('state', 'code', 'error'):
            self.params[key] = qs[key][0] if key in qs else None

        body = self.build_body()

        # send_response also sets the Server and Date headers
        self.send_response(200)
        self.send_header('Content-Type', 'text/html; charset=UTF-8')
        self.send_header('Content-Length', len(body))
        self.end_headers()

        self.wfile.write(body)

        if self.shutdown_on_request:
            # Shutdown the server after serving the request
            # http://stackoverflow.com/a/22533929
            # shutdown() blocks until the serve loop exits, so it has to be
            # called from a thread other than the one serving this request.
            thread = threading.Thread(target=self.server.shutdown)
            thread.daemon = True
            thread.start()

    def log_message(self, format, *args):
        """
        Redirect logging to our own handler instead of stdout
        """
        _logger.debug(format, *args)

    @staticmethod
    def _escape_html(text):
        """
        Escape the characters that are significant in HTML markup.

        Hand-rolled with str.replace so that it does not need an import
        that only exists on one major python version.
        """
        # '&' must be replaced first so that the entities produced by the
        # later replacements are not escaped a second time.
        return (text.replace('&', '&amp;')
                    .replace('<', '&lt;')
                    .replace('>', '&gt;')
                    .replace('"', '&quot;')
                    .replace("'", '&#x27;'))

    def build_body(self, template_file=INDEX):
        """
        Render the response page for the current callback params.

        Params:
            template_file (text): Path to an index.html template

        Returns:
            body (bytes): The utf-8 encoded document body
        """

        error = self.params['error']
        if error == 'access_denied':
            message = docs.OAUTH_ACCESS_DENIED
        elif error is not None:
            # The error string comes straight from the request's query
            # string, so it must be escaped before it is placed in the page.
            message = docs.OAUTH_ERROR.format(error=self._escape_html(error))
        elif self.params['state'] is None or self.params['code'] is None:
            message = docs.OAUTH_INVALID
        else:
            message = docs.OAUTH_SUCCESS

        with codecs.open(template_file, 'r', 'utf-8') as fp:
            index_text = fp.read()

        body = string.Template(index_text).substitute(message=message)
        body = codecs.encode(body, 'utf-8')
        return body
|
|
|
|
|
|
class OAuthHelper(object):
    """
    Drives the OAuth login flow for a reddit session.

    Either refreshes the session from a cached refresh token, or opens the
    authorization page in a browser and runs a local HTTP server that waits
    for the redirect callback.
    """

    # Same dict object as the handler's, so values written by the request
    # handler are visible here.
    params = OAuthHandler.params

    def __init__(self, reddit, term, config):
        """
        Params:
            reddit: The reddit session that will be authorized
            term: Terminal object used for loaders, notifications and for
                opening the browser
            config: Configuration holding the oauth settings and the
                cached refresh token
        """

        self.term = term
        self.reddit = reddit
        self.config = config

        # Wait to initialize the server, we don't want to reserve the port
        # unless we know that the server needs to be used.
        self.server = None

        self.reddit.set_oauth_app_info(
            self.config['oauth_client_id'],
            self.config['oauth_client_secret'],
            self.config['oauth_redirect_uri'])

        # Reddit's mobile website works better on terminal browsers
        if not self.term.display:
            if '.compact' not in self.reddit.config.API_PATHS['authorize']:
                self.reddit.config.API_PATHS['authorize'] += '.compact'

    def authorize(self):
        """
        Log the user in, using the cached refresh token when one exists.

        Without a refresh token, the authorization URL is opened in a
        browser and a local server waits for the callback. Failures during
        the browser flow are reported through notifications and the method
        simply returns.

        Raises:
            InvalidRefreshToken: The cached refresh token was rejected; the
                stored oauth data is cleared before raising.
        """

        self.params.update(state=None, code=None, error=None)

        # If we already have a token, request new access credentials
        if self.config.refresh_token:
            with self.term.loader('Logging in'):
                try:
                    self.reddit.refresh_access_information(
                        self.config.refresh_token)
                except (HTTPException, OAuthException) as e:
                    # Reddit didn't accept the refresh-token
                    # This appears to throw a generic 400 error instead of the
                    # more specific invalid_token message that it used to send
                    if isinstance(e, HTTPException):
                        if e._raw.status_code != 400:
                            # No special handling if the error is something
                            # temporary like a 5XX.
                            raise

                    # Otherwise we know the token is bad, so we can remove it.
                    _logger.exception(e)
                    self.clear_oauth_data()
                    raise InvalidRefreshToken(
                        '       Invalid user credentials!\n'
                        'The cached refresh token has been removed')
            return

        state = uuid.uuid4().hex
        authorize_url = self.reddit.get_authorize_url(
            state, scope=self.config['oauth_scope'], refreshable=True)

        if self.server is None:
            address = ('', self.config['oauth_redirect_port'])
            self.server = HTTPServer(address, OAuthHandler)

        if self.term.display:
            # Open a background browser (e.g. firefox) which is non-blocking.
            # The server will block until it responds to its first request,
            # at which point we can check the callback params.
            OAuthHandler.shutdown_on_request = True
            with self.term.loader('Opening browser for authorization'):
                self.term.open_browser(authorize_url)
                self.server.serve_forever()
            if self.term.loader.exception:
                # Don't need to call server.shutdown() because serve_forever()
                # is wrapped in a try-finally that does it for us.
                return
        else:
            # Open the terminal webbrowser in a background thread and wait
            # while for the user to close the process. Once the process is
            # closed, the iloop is stopped and we can check if the user has
            # hit the callback URL.
            OAuthHandler.shutdown_on_request = False
            with self.term.loader('Redirecting to reddit', delay=0):
                # This load message exists to provide user feedback
                time.sleep(1)

            thread = threading.Thread(target=self.server.serve_forever)
            thread.daemon = True
            thread.start()
            try:
                self.term.open_browser(authorize_url)
            except Exception as e:
                _logger.exception(e)
                self.term.show_notification('Browser Error', style='error')
            finally:
                # The handler never stops the server on its own in this
                # mode, so the server must be shut down explicitly on every
                # path. Otherwise serve_forever() never returns and the
                # join() below blocks forever after a browser error.
                self.server.shutdown()
                thread.join()

        if self.params['error'] == 'access_denied':
            self.term.show_notification('Denied access', style='error')
            return
        elif self.params['error']:
            self.term.show_notification('Authentication error', style='error')
            return
        elif self.params['state'] is None:
            # Something went wrong but it's not clear what happened
            return
        elif self.params['state'] != state:
            self.term.show_notification('UUID mismatch', style='error')
            return

        with self.term.loader('Logging in'):
            info = self.reddit.get_access_information(self.params['code'])
        if self.term.loader.exception:
            return

        message = 'Welcome {}!'.format(self.reddit.user.name)
        self.term.show_notification(message)

        self.config.refresh_token = info['refresh_token']
        if self.config['persistent']:
            self.config.save_refresh_token()

    def clear_oauth_data(self):
        """
        Drop the session's authentication and delete the stored refresh
        token.
        """
        self.reddit.clear_authentication()
        self.config.delete_refresh_token()
|
|
|
|
|
|
def fix_cache(func):
    """
    Decorator that normalizes the cache arguments of a handler request.

    Two adjustments are made before delegating to ``func``:
        - The third element of the cache key's item tuple (the request's
          session cookies) is dropped from the key.
        - Caching is disabled for every request whose method is not GET.

    Params:
        func (callable): Called as
            ``func(self, _cache_key, _cache_ignore, *args, **kwargs)``. The
            ``request`` object must be supplied as a keyword argument.

    Returns:
        callable: The wrapping function, carrying ``func``'s metadata.
    """
    # Imported locally so the decorator does not rely on a module-level
    # import being present.
    import functools

    @functools.wraps(func)
    def wrapper(self, _cache_key, _cache_ignore, *args, **kwargs):
        if _cache_key:
            # Pop the request's session cookies from the cache key.
            # These appear to be unreliable and change with every
            # request. Also, with the introduction of OAuth I don't think
            # that cookies are being used to store anything that
            # differentiates API requests anyways
            url, items = _cache_key
            _cache_key = (url, (items[0], items[1], items[3], items[4]))

        if kwargs['request'].method != 'GET':
            # Why were POST/PUT/DELETE requests ever being cached???
            _cache_ignore = True

        return func(self, _cache_key, _cache_ignore, *args, **kwargs)
    return wrapper
|
|
|
|
|
|
class OAuthRateLimitHandler(DefaultHandler):
    """PRAW request handler that throttles based on reddit's response headers.

    Replaces the rate limiting of PRAW 3's DefaultHandler with one that
    follows Reddit's modern API guidelines, i.e. it reacts to the
    X-Ratelimit-* headers that come back with each response.

    References:
        https://github.com/reddit/reddit/wiki/API
        https://github.com/praw-dev/prawcore/blob/master/prawcore/rate_limit.py
    """

    # Earliest time (as returned by time.time()) at which the next request
    # may be sent, or None when no waiting is required.
    next_request_timestamp = None

    def delay(self):
        """
        Block until the next HTTP request is allowed to go out.
        """
        deadline = self.next_request_timestamp
        if deadline is not None:
            wait = deadline - time.time()
            if wait > 0:
                time.sleep(wait)

    def update(self, response_headers):
        """
        Refresh the limiter state from a response's headers:

            X-Ratelimit-Used: Approximate number of requests used this period
            X-Ratelimit-Remaining: Approximate number of requests left to use
            X-Ratelimit-Reset: Approximate number of seconds to end of period

        PRAW 5 spaces requests out evenly, which suits bots and crawlers
        that issue hundreds of API calls. This handler targets interactive
        use instead: short, sporadic bursts are let through untouched, on
        the assumption that a person browsing reddit should never come
        close to the limit. Only once the limit has actually been reached
        are requests held back until the period resets.
        """

        if 'x-ratelimit-remaining' not in response_headers:
            # Either the API answered with an error response, or we are
            # using something like read-only credentials, which Reddit
            # doesn't appear to rate limit.
            return

        used = float(response_headers['x-ratelimit-used'])
        remaining = float(response_headers['x-ratelimit-remaining'])
        seconds_to_reset = int(response_headers['x-ratelimit-reset'])
        _logger.debug('Rate limit: %s used, %s remaining, %s reset',
                      used, remaining, seconds_to_reset)

        # Only hold requests back when the quota is exhausted; otherwise
        # clear any previously scheduled wait.
        self.next_request_timestamp = (
            time.time() + seconds_to_reset if remaining <= 0 else None)

    @fix_cache
    @DefaultHandler.with_cache
    def request(self, request, proxies, timeout, verify, **_):
        """
        Send the prepared request, waiting first if the rate limit demands
        it, and update the limiter from the response headers.
        """

        env = self.http.merge_environment_settings(
            request.url, proxies, False, verify, None)

        self.delay()
        send_kwargs = dict(env, timeout=timeout, allow_redirects=False)
        response = self.http.send(request, **send_kwargs)
        self.update(response.headers)
        return response
|