diff --git a/rtv/oauth.py b/rtv/oauth.py index ec53811..ca2e7d9 100644 --- a/rtv/oauth.py +++ b/rtv/oauth.py @@ -24,6 +24,18 @@ _logger = logging.getLogger(__name__) INDEX = os.path.join(TEMPLATES, 'index.html') +class OAuthHTTPServer(HTTPServer): + + def handle_error(self, request, client_address): + """ + The default HTTPServer's error handler prints the request traceback + to stdout, which breaks the curses display. + + Override it to log to a file instead. + """ + _logger.exception('Error processing request in OAuth HTTP Server') + + class OAuthHandler(BaseHTTPRequestHandler): # params are stored as a global because we don't have control over what @@ -160,7 +172,7 @@ class OAuthHelper(object): if self.server is None: address = ('', self.config['oauth_redirect_port']) - self.server = HTTPServer(address, OAuthHandler) + self.server = OAuthHTTPServer(address, OAuthHandler) if self.term.display: # Open a background browser (e.g. firefox) which is non-blocking. diff --git a/tests/conftest.py b/tests/conftest.py index 69e371d..a72b1d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,9 +10,8 @@ from functools import partial import pytest from vcr import VCR from six.moves.urllib.parse import urlparse, parse_qs -from six.moves.BaseHTTPServer import HTTPServer -from rtv.oauth import OAuthHelper, OAuthHandler +from rtv.oauth import OAuthHelper, OAuthHandler, OAuthHTTPServer from rtv.content import RequestHeaderRateLimiter from rtv.config import Config from rtv.packages import praw @@ -216,7 +215,7 @@ def oauth(reddit, terminal, config): @pytest.yield_fixture() def oauth_server(): # Start the OAuth server on a random port in the background - server = HTTPServer(('', 0), OAuthHandler) + server = OAuthHTTPServer(('', 0), OAuthHandler) server.url = 'http://{0}:{1}/'.format(*server.server_address) thread = threading.Thread(target=server.serve_forever) thread.start() diff --git a/tests/test_oauth.py b/tests/test_oauth.py index a98b6b0..2a988fb 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -117,7 +117,7 @@ def test_oauth_authorize(oauth, reddit, stdscr, refresh_token): # function in the destination oauth module and not the helpers module with mock.patch('uuid.UUID.hex', new_callable=mock.PropertyMock) as uuid, \ mock.patch('rtv.terminal.Terminal.open_browser') as open_browser, \ - mock.patch('rtv.oauth.HTTPServer') as http_server, \ + mock.patch('rtv.oauth.OAuthHTTPServer') as http_server, \ mock.patch.object(oauth.reddit, 'user'), \ mock.patch('time.sleep'):