Making real progress this time.
This commit is contained in:
117
rtv/oauth.py
117
rtv/oauth.py
@@ -1,62 +1,91 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import string
|
||||
import codecs
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tornado import gen, ioloop, web, httpserver
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
||||
from .config import TEMPLATE
|
||||
from . import docs
|
||||
from .config import TEMPLATES
|
||||
|
||||
|
||||
class OAuthHandler(web.RequestHandler):
|
||||
"""
|
||||
Intercepts the redirect that Reddit sends the user to after they verify or
|
||||
deny the application access.
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
The GET should supply 3 request params:
|
||||
state: Unique id that was supplied by us at the beginning of the
|
||||
process to verify that the session matches.
|
||||
code: Code that we can use to generate the refresh token.
|
||||
error: If an error occurred, it will be placed here.
|
||||
"""
|
||||
INDEX = os.path.join(TEMPLATES, 'index.html')
|
||||
|
||||
def initialize(self, display=None, params=None):
|
||||
self.display = display
|
||||
self.params = params
|
||||
|
||||
def get(self):
|
||||
self.params['state'] = self.get_argument('state', default=None)
|
||||
self.params['code'] = self.get_argument('code', default=None)
|
||||
self.params['error'] = self.get_argument('error', default=None)
|
||||
class OAuthHandler(BaseHTTPRequestHandler):
|
||||
|
||||
self.render('index.html', **self.params)
|
||||
params = {'state': None, 'code': None, 'error': None}
|
||||
|
||||
complete = self.params['state'] and self.params['code']
|
||||
if complete or self.params['error']:
|
||||
# Stop IOLoop if using a background browser such as firefox
|
||||
if self.display:
|
||||
ioloop.IOLoop.current().stop()
|
||||
def do_GET(self):
|
||||
|
||||
parsed_path = urlparse(self.path)
|
||||
if parsed_path.path != '/':
|
||||
self.send_error(404)
|
||||
|
||||
qs = parse_qs(parsed_path.query)
|
||||
self.params['state'] = qs['state'][0] if 'state' in qs else None
|
||||
self.params['code'] = qs['code'][0] if 'code' in qs else None
|
||||
self.params['error'] = qs['error'][0] if 'error' in qs else None
|
||||
|
||||
body = self.build_body()
|
||||
|
||||
# send_response also sets the Server and Date headers
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'text/html; charset=UTF-8')
|
||||
self.send_header('Content-Length', len(body))
|
||||
self.end_headers()
|
||||
|
||||
self.wfile.write(body)
|
||||
|
||||
# Shutdown the server after serving the request
|
||||
# http://stackoverflow.com/a/22533929
|
||||
thread = threading.Thread(target=self.server.shutdown)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
def build_body(self, template_file=INDEX):
|
||||
|
||||
if self.params['error'] == 'access_denied':
|
||||
message = docs.OAUTH_ACCESS_DENIED
|
||||
elif self.params['error'] is not None:
|
||||
message = docs.OAUTH_ERROR.format(error=self.params['error'])
|
||||
elif self.params['state'] is None or self.params['code'] is None:
|
||||
message = docs.OAUTH_INVALID
|
||||
else:
|
||||
message = docs.OAUTH_SUCCESS
|
||||
|
||||
with codecs.open(template_file, 'r', 'utf-8') as fp:
|
||||
index_text = fp.read()
|
||||
|
||||
body = string.Template(index_text).substitute(message=message)
|
||||
body = codecs.encode(body, 'utf-8')
|
||||
return body
|
||||
|
||||
def log_message(self, format, *args):
|
||||
_logger.debug(format, *args)
|
||||
|
||||
|
||||
class OAuthHelper(object):
|
||||
|
||||
params = OAuthHandler.params
|
||||
|
||||
def __init__(self, reddit, term, config):
|
||||
|
||||
self.term = term
|
||||
self.reddit = reddit
|
||||
self.config = config
|
||||
|
||||
self.http_server = None
|
||||
self.params = {'state': None, 'code': None, 'error': None}
|
||||
|
||||
# Initialize Tornado webapp
|
||||
# Pass a mutable params object so the request handler can modify it
|
||||
kwargs = {'display': self.term.display, 'params': self.params}
|
||||
routes = [('/', OAuthHandler, kwargs)]
|
||||
self.callback_app = web.Application(
|
||||
routes, template_path=TEMPLATE)
|
||||
address = ('', self.config['oauth_redirect_port'])
|
||||
self.server = HTTPServer(address, OAuthHandler)
|
||||
|
||||
self.reddit.set_oauth_app_info(
|
||||
self.config['oauth_client_id'],
|
||||
@@ -79,14 +108,6 @@ class OAuthHelper(object):
|
||||
self.config.refresh_token)
|
||||
return
|
||||
|
||||
# https://github.com/tornadoweb/tornado/issues/1420
|
||||
io = ioloop.IOLoop.current()
|
||||
|
||||
# Start the authorization callback server
|
||||
if self.http_server is None:
|
||||
self.http_server = httpserver.HTTPServer(self.callback_app)
|
||||
self.http_server.listen(self.config['oauth_redirect_port'])
|
||||
|
||||
state = uuid.uuid4().hex
|
||||
authorize_url = self.reddit.get_authorize_url(
|
||||
state, scope=self.config['oauth_scope'], refreshable=True)
|
||||
@@ -97,7 +118,7 @@ class OAuthHelper(object):
|
||||
# point we continue and check the callback params.
|
||||
with self.term.loader('Opening browser for authorization'):
|
||||
self.term.open_browser(authorize_url)
|
||||
io.start()
|
||||
self.server.serve_forever()
|
||||
if self.term.loader.exception:
|
||||
return
|
||||
else:
|
||||
@@ -138,10 +159,4 @@ class OAuthHelper(object):
|
||||
|
||||
def clear_oauth_data(self):
|
||||
self.reddit.clear_authentication()
|
||||
self.config.delete_refresh_token()
|
||||
|
||||
@gen.coroutine
|
||||
def _async_open_browser(self, url):
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
yield executor.submit(self.term.open_browser, url)
|
||||
ioloop.IOLoop.current().stop()
|
||||
self.config.delete_refresh_token()
|
||||
Reference in New Issue
Block a user