Files
tuir/rtv/packages/praw/internal.py
2017-03-28 21:33:10 -07:00

272 lines
10 KiB
Python

# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""Internal helper functions.
The functions in this module are not to be relied upon by third-parties.
"""
from __future__ import print_function, unicode_literals
import os
import re
import six
import sys
from requests import Request, codes, exceptions
from requests.compat import urljoin
from praw.decorators import restrict_access
from praw.errors import (ClientException, HTTPException, Forbidden, NotFound,
InvalidSubreddit, OAuthException,
OAuthInsufficientScope, OAuthInvalidToken,
RedirectException)
from warnings import warn
try:
from OpenSSL import __version__ as _opensslversion
_opensslversionlist = [int(minor) if minor.isdigit() else minor
for minor in _opensslversion.split('.')]
except ImportError:
_opensslversionlist = [0, 15]
MIN_PNG_SIZE = 67
MIN_JPEG_SIZE = 128
MAX_IMAGE_SIZE = 512000
JPEG_HEADER = b'\xff\xd8\xff'
PNG_HEADER = b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
RE_REDIRECT = re.compile('(rand(om|nsfw))|about/sticky')
def _get_redditor_listing(subpath=''):
"""Return function to generate Redditor listings."""
def _listing(self, sort='new', time='all', *args, **kwargs):
"""Return a get_content generator for some RedditContentObject type.
:param sort: Specify the sort order of the results if applicable
(one of ``'hot'``, ``'new'``, ``'top'``, ``'controversial'``).
:param time: Specify the time-period to return submissions if
applicable (one of ``'hour'``, ``'day'``, ``'week'``,
``'month'``, ``'year'``, ``'all'``).
The additional parameters are passed directly into
:meth:`.get_content`. Note: the `url` parameter cannot be altered.
"""
kwargs.setdefault('params', {})
kwargs['params'].setdefault('sort', sort)
kwargs['params'].setdefault('t', time)
url = urljoin(self._url, subpath) # pylint: disable=W0212
return self.reddit_session.get_content(url, *args, **kwargs)
return _listing
def _get_sorter(subpath='', **defaults):
"""Return function to generate specific subreddit Submission listings."""
@restrict_access(scope='read')
def _sorted(self, *args, **kwargs):
"""Return a get_content generator for some RedditContentObject type.
The additional parameters are passed directly into
:meth:`.get_content`. Note: the `url` parameter cannot be altered.
"""
if not kwargs.get('params'):
kwargs['params'] = {}
for key, value in six.iteritems(defaults):
kwargs['params'].setdefault(key, value)
url = urljoin(self._url, subpath) # pylint: disable=W0212
return self.reddit_session.get_content(url, *args, **kwargs)
return _sorted
def _image_type(image):
size = os.path.getsize(image.name)
if size < MIN_PNG_SIZE:
raise ClientException('png image is too small.')
if size > MAX_IMAGE_SIZE:
raise ClientException('`image` is too big. Max: {0} bytes'
.format(MAX_IMAGE_SIZE))
first_bytes = image.read(MIN_PNG_SIZE)
image.seek(0)
if first_bytes.startswith(PNG_HEADER):
return 'png'
elif first_bytes.startswith(JPEG_HEADER):
if size < MIN_JPEG_SIZE:
raise ClientException('jpeg image is too small.')
return 'jpg'
raise ClientException('`image` must be either jpg or png.')
def _modify_relationship(relationship, unlink=False, is_sub=False):
"""Return a function for relationship modification.
Used to support friending (user-to-user), as well as moderating,
contributor creating, and banning (user-to-subreddit).
"""
# The API uses friend and unfriend to manage all of these relationships.
url_key = 'unfriend' if unlink else 'friend'
if relationship == 'friend':
access = {'scope': None, 'login': True}
elif relationship == 'moderator':
access = {'scope': 'modothers'}
elif relationship in ['banned', 'contributor', 'muted']:
access = {'scope': 'modcontributors'}
elif relationship in ['wikibanned', 'wikicontributor']:
access = {'scope': ['modcontributors', 'modwiki']}
else:
access = {'scope': None, 'mod': True}
@restrict_access(**access)
def do_relationship(thing, user, **kwargs):
data = {'name': six.text_type(user),
'type': relationship}
data.update(kwargs)
if is_sub:
data['r'] = six.text_type(thing)
else:
data['container'] = thing.fullname
session = thing.reddit_session
if relationship == 'moderator':
session.evict(session.config['moderators'].format(
subreddit=six.text_type(thing)))
url = session.config[url_key]
return session.request_json(url, data=data)
return do_relationship
def _prepare_request(reddit_session, url, params, data, auth, files,
method=None):
"""Return a requests Request object that can be "prepared"."""
# Requests using OAuth for authorization must switch to using the oauth
# domain.
if getattr(reddit_session, '_use_oauth', False):
bearer = 'bearer {0}'.format(reddit_session.access_token)
headers = {'Authorization': bearer}
config = reddit_session.config
for prefix in (config.api_url, config.permalink_url):
if url.startswith(prefix):
if config.log_requests >= 1:
msg = 'substituting {0} for {1} in url\n'.format(
config.oauth_url, prefix)
sys.stderr.write(msg)
url = config.oauth_url + url[len(prefix):]
break
else:
headers = {}
headers.update(reddit_session.http.headers)
if method:
pass
elif data or files:
method = 'POST'
else:
method = 'GET'
# Log the request if logging is enabled
if reddit_session.config.log_requests >= 1:
sys.stderr.write('{0}: {1}\n'.format(method, url))
if reddit_session.config.log_requests >= 2:
if params:
sys.stderr.write('params: {0}\n'.format(params))
if data:
sys.stderr.write('data: {0}\n'.format(data))
if auth:
sys.stderr.write('auth: {0}\n'.format(auth))
# Prepare request
request = Request(method=method, url=url, headers=headers, params=params,
auth=auth, cookies=reddit_session.http.cookies)
if method == 'GET':
return request
# Most POST requests require adding `api_type` and `uh` to the data.
if data is True:
data = {}
if isinstance(data, dict):
if not auth:
data.setdefault('api_type', 'json')
if reddit_session.modhash:
data.setdefault('uh', reddit_session.modhash)
else:
request.headers.setdefault('Content-Type', 'application/json')
request.data = data
request.files = files
return request
def _raise_redirect_exceptions(response):
"""Return the new url or None if there are no redirects.
Raise exceptions if appropriate.
"""
if response.status_code not in [301, 302, 307]:
return None
new_url = urljoin(response.url, response.headers['location'])
if 'reddits/search' in new_url: # Handle non-existent subreddit
subreddit = new_url.rsplit('=', 1)[1]
raise InvalidSubreddit('`{0}` is not a valid subreddit'
.format(subreddit))
elif not RE_REDIRECT.search(response.url):
raise RedirectException(response.url, new_url)
return new_url
def _raise_response_exceptions(response):
"""Raise specific errors on some status codes."""
if not response.ok and 'www-authenticate' in response.headers:
msg = response.headers['www-authenticate']
if 'insufficient_scope' in msg:
raise OAuthInsufficientScope('insufficient_scope', response.url)
elif 'invalid_token' in msg:
raise OAuthInvalidToken('invalid_token', response.url)
else:
raise OAuthException(msg, response.url)
if response.status_code == codes.forbidden: # pylint: disable=E1101
raise Forbidden(_raw=response)
elif response.status_code == codes.not_found: # pylint: disable=E1101
raise NotFound(_raw=response)
else:
try:
response.raise_for_status() # These should all be directly mapped
except exceptions.HTTPError as exc:
raise HTTPException(_raw=exc.response)
def _to_reddit_list(arg):
"""Return an argument converted to a reddit-formatted list.
The returned format is a comma deliminated list. Each element is a string
representation of an object. Either given as a string or as an object that
is then converted to its string representation.
"""
if (isinstance(arg, six.string_types) or not (
hasattr(arg, "__getitem__") or hasattr(arg, "__iter__"))):
return six.text_type(arg)
else:
return ','.join(six.text_type(a) for a in arg)
def _warn_pyopenssl():
"""Warn the user against faulty versions of pyOpenSSL."""
if _opensslversionlist < [0, 15]: # versions >= 0.15 are fine
warn(RuntimeWarning(
"pyOpenSSL {0} may be incompatible with praw if validating"
"ssl certificates, which is on by default.\nSee https://"
"github.com/praw/pull/625 for more information".format(
_opensslversion)
))