Bundling praw v3 with rtv

Michael Lazar
2017-03-28 21:33:10 -07:00
parent 31024e1a6c
commit 6b1eab1a97
24 changed files with 6930 additions and 23 deletions


@@ -9,10 +9,10 @@ import logging
 import warnings
 import six
-import praw
 import requests
 from . import docs
+from .packages import praw
 from .config import Config, copy_default_config, copy_default_mailcap
 from .oauth import OAuthHelper
 from .terminal import Terminal


@@ -6,11 +6,11 @@ import logging
 from datetime import datetime
 import six
-import praw
-from praw.errors import InvalidSubreddit
 from kitchen.text.display import wrap
 from . import exceptions
+from .packages import praw
+from .packages.praw.errors import InvalidSubreddit
 _logger = logging.getLogger(__name__)


@@ -14,10 +14,10 @@ import curses.ascii
 from contextlib import contextmanager
 import six
-import praw
 import requests
 from . import exceptions
+from .packages import praw
 _logger = logging.getLogger(__name__)

rtv/packages/__init__.py (new file, +23 lines)

@@ -0,0 +1,23 @@
"""
This stub allows the end-user to fall back to their system installation of praw
if the bundled package is missing. This technique was inspired by the requests
library and how it handles dependencies.
Reference:
https://github.com/kennethreitz/requests/blob/master/requests/packages/__init__.py
"""
from __future__ import absolute_import
import sys
__praw_hash__ = 'a632ff005fc09e74a8d3d276adc10aa92638962c'
try:
from . import praw
except ImportError:
import praw
if not praw.__version__.startswith('3.'):
msg = 'Invalid PRAW version {0}, exiting'.format(praw.__version__)
raise RuntimeError(msg)
sys.modules['%s.praw' % __name__] = praw
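
As a hedged illustration of how the stub is meant to be consumed (mirroring the import changes in the hunks above), rtv modules now go through `rtv.packages` instead of importing `praw` directly:

```python
# Sketch only: whichever copy the stub resolved (bundled or system praw 3.x)
# is exposed under the rtv.packages namespace.
from rtv.packages import praw
from rtv.packages.praw.errors import InvalidSubreddit

# A non-3.x system praw would have raised RuntimeError inside the stub.
print(praw.__version__)
```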

File diff suppressed because it is too large.


@@ -0,0 +1,38 @@
"""Internal helper functions used by praw.decorators."""
import inspect
from requests.compat import urljoin
import six
import sys
def _get_captcha(reddit_session, captcha_id):
"""Prompt user for captcha solution and return a prepared result."""
url = urljoin(reddit_session.config['captcha'],
captcha_id + '.png')
sys.stdout.write('Captcha URL: {0}\nCaptcha: '.format(url))
sys.stdout.flush()
raw = sys.stdin.readline()
if not raw: # stdin has reached the end of file
# Trigger exception raising next time through. The request is
# cached so this will not require an extra request and delay.
sys.stdin.close()
return None
return {'iden': captcha_id, 'captcha': raw.strip()}
def _is_mod_of_all(user, subreddit):
mod_subs = user.get_cached_moderated_reddits()
subs = six.text_type(subreddit).lower().split('+')
return all(sub in mod_subs for sub in subs)
def _make_func_args(function):
if six.PY3 and not hasattr(sys, 'pypy_version_info'):
# CPython3 uses inspect.signature(), not inspect.getargspec()
# see #551 and #541 for more info
func_items = inspect.signature(function).parameters.items()
func_args = [name for name, param in func_items
if param.kind == param.POSITIONAL_OR_KEYWORD]
else:
func_args = inspect.getargspec(function).args
return func_args
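
For context, a standalone sketch of the introspection technique `_make_func_args` uses (the example function is hypothetical): `inspect.signature` on CPython 3, `inspect.getargspec` elsewhere, both producing the list of positional-or-keyword parameter names.

```python
import inspect
import sys

import six


def example(self, subreddit, captcha=None):  # hypothetical target function
    pass


if six.PY3 and not hasattr(sys, 'pypy_version_info'):
    items = inspect.signature(example).parameters.items()
    names = [name for name, param in items
             if param.kind == param.POSITIONAL_OR_KEYWORD]
else:
    names = inspect.getargspec(example).args

print(names)  # ['self', 'subreddit', 'captcha']
```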


@@ -0,0 +1,294 @@
# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""
Decorators.
They mainly do two things: ensure API guidelines are followed and
prevent unnecessary failed API requests by testing that the call can be made
first. Also, they can limit the length of output strings and parse JSON
responses for certain errors.
"""
from __future__ import print_function, unicode_literals
import decorator
import six
import sys
from functools import wraps
from praw.decorator_helpers import (
_get_captcha,
_is_mod_of_all,
_make_func_args
)
from praw import errors
from warnings import filterwarnings, warn
# Enable deprecation warnings from this module
filterwarnings('default', category=DeprecationWarning,
module='^praw\.decorators$')
def alias_function(function, class_name):
"""Create a RedditContentObject function mapped to a BaseReddit function.
The BaseReddit classes define the majority of the API's functions. The
first argument for many of these functions is the RedditContentObject that
they operate on. This factory returns functions appropriate to be called on
a RedditContent object that maps to the corresponding BaseReddit function.
"""
@wraps(function)
def wrapped(self, *args, **kwargs):
func_args = _make_func_args(function)
if 'subreddit' in func_args and func_args.index('subreddit') != 1:
# Only happens for search
kwargs['subreddit'] = self
return function(self.reddit_session, *args, **kwargs)
else:
return function(self.reddit_session, self, *args, **kwargs)
# Only grab the short-line doc and add a link to the complete doc
if wrapped.__doc__ is not None:
wrapped.__doc__ = wrapped.__doc__.split('\n', 1)[0]
wrapped.__doc__ += ('\n\nSee :meth:`.{0}.{1}` for complete usage. '
'Note that you should exclude the subreddit '
'parameter when calling this convenience method.'
.format(class_name, function.__name__))
# Don't hide from sphinx as this is a parameter modifying decorator
return wrapped
def deprecated(msg=''):
"""Deprecate decorated method."""
@decorator.decorator
def wrap(function, *args, **kwargs):
if not kwargs.pop('disable_warning', False):
warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return wrap
@decorator.decorator
def limit_chars(function, *args, **kwargs):
"""Truncate the string returned from a function and return the result."""
output_chars_limit = args[0].reddit_session.config.output_chars_limit
output_string = function(*args, **kwargs)
if -1 < output_chars_limit < len(output_string):
output_string = output_string[:output_chars_limit - 3] + '...'
return output_string
@decorator.decorator
def oauth_generator(function, *args, **kwargs):
"""Set the _use_oauth keyword argument to True when appropriate.
This is needed because generator functions may be called at any time, and
PRAW relies on the Reddit._use_oauth value at original call time to know
when to make OAuth requests.
Returned data is not modified.
"""
if getattr(args[0], '_use_oauth', False):
kwargs['_use_oauth'] = True
return function(*args, **kwargs)
@decorator.decorator
def raise_api_exceptions(function, *args, **kwargs):
"""Raise client side exception(s) when present in the API response.
Returned data is not modified.
"""
try:
return_value = function(*args, **kwargs)
except errors.HTTPException as exc:
if exc._raw.status_code != 400: # pylint: disable=W0212
raise # Unhandled HTTPErrors
try: # Attempt to convert v1 errors into older format (for now)
data = exc._raw.json() # pylint: disable=W0212
assert len(data) == 2
return_value = {'errors': [(data['reason'],
data['explanation'], '')]}
except Exception:
raise exc
if isinstance(return_value, dict):
if return_value.get('error') == 304: # Not modified exception
raise errors.NotModified(return_value)
elif return_value.get('errors'):
error_list = []
for error_type, msg, value in return_value['errors']:
if error_type in errors.ERROR_MAPPING:
if error_type == 'RATELIMIT':
args[0].evict(args[1])
error_class = errors.ERROR_MAPPING[error_type]
else:
error_class = errors.APIException
error_list.append(error_class(error_type, msg, value,
return_value))
if len(error_list) == 1:
raise error_list[0]
else:
raise errors.ExceptionList(error_list)
return return_value
@decorator.decorator
def require_captcha(function, *args, **kwargs):
"""Return a decorator for methods that require captchas."""
raise_captcha_exception = kwargs.pop('raise_captcha_exception', False)
captcha_id = None
# Get a handle to the reddit session
if hasattr(args[0], 'reddit_session'):
reddit_session = args[0].reddit_session
else:
reddit_session = args[0]
while True:
try:
if captcha_id:
captcha_answer = _get_captcha(reddit_session, captcha_id)
# When the method is being decorated, all of its default
# parameters become part of this *args tuple. This means that
# *args currently contains a None where the captcha answer
# needs to go. If we put the captcha in the **kwargs,
# we get a TypeError for having two values of the same param.
func_args = _make_func_args(function)
if 'captcha' in func_args:
captcha_index = func_args.index('captcha')
args = list(args)
args[captcha_index] = captcha_answer
else:
kwargs['captcha'] = captcha_answer
return function(*args, **kwargs)
except errors.InvalidCaptcha as exception:
if raise_captcha_exception or \
not hasattr(sys.stdin, 'closed') or sys.stdin.closed:
raise
captcha_id = exception.response['captcha']
def restrict_access(scope, mod=None, login=None, oauth_only=False,
generator_called=False):
"""Restrict function access unless the user has the necessary permissions.
Raises one of the following exceptions when appropriate:
* LoginRequired
* LoginOrOAuthRequired
* the scope attribute will provide the necessary scope name
* ModeratorRequired
* ModeratorOrOAuthRequired
* the scope attribute will provide the necessary scope name
:param scope: Indicate the scope that is required for the API call. None or
False must be passed to indicate that no scope handles the API call.
All scopes save for `read` imply login=True. Scopes with 'mod' in their
name imply mod=True.
:param mod: Indicate that a moderator is required. Implies login=True.
:param login: Indicate that a login is required.
:param oauth_only: Indicate that only OAuth is supported for the function.
:param generator_called: Indicate that the function consists solely of
exhausting one or more oauth_generator wrapped generators. This is
because the oauth_generator itself will determine whether or not to
use the oauth domain.
Returned data is not modified.
This decorator assumes that all mod required functions fit one of these
categories:
* have the subreddit as the first argument (Reddit instance functions) or
have a subreddit keyword argument
* are called upon a subreddit object (Subreddit RedditContentObject)
* are called upon a RedditContent object with attribute subreddit
"""
if not scope and oauth_only:
raise TypeError('`scope` must be set when `oauth_only` is set')
mod = mod is not False and (mod or scope and 'mod' in scope)
login = login is not False and (login or mod or scope and scope != 'read')
@decorator.decorator
def wrap(function, *args, **kwargs):
if args[0] is None: # Occurs with (un)friend
assert login
raise errors.LoginRequired(function.__name__)
# This segment of code uses hasattr to determine what instance type
# the function was called on. We could use isinstance if we wanted
# to import the types at runtime (decorators is used by all the
# types).
if mod:
if hasattr(args[0], 'reddit_session'):
# Defer access until necessary for RedditContentObject.
# This is because scoped sessions may not require this
# attribute to exist, thus it might not be set.
from praw.objects import Subreddit
subreddit = args[0] if isinstance(args[0], Subreddit) \
else False
else:
subreddit = kwargs.get(
'subreddit', args[1] if len(args) > 1 else None)
if subreddit is None: # Try the default value
defaults = six.get_function_defaults(function)
subreddit = defaults[0] if defaults else None
else:
subreddit = None
obj = getattr(args[0], 'reddit_session', args[0])
# This function sets _use_oauth for one time use only.
# Verify that statement is actually true.
assert not obj._use_oauth # pylint: disable=W0212
if scope and obj.has_scope(scope):
obj._use_oauth = not generator_called # pylint: disable=W0212
elif oauth_only:
raise errors.OAuthScopeRequired(function.__name__, scope)
elif login and obj.is_logged_in():
if subreddit is False:
# Now fetch the subreddit attribute. There is no good
# reason for it to not be set during a logged in session.
subreddit = args[0].subreddit
if mod and not _is_mod_of_all(obj.user, subreddit):
if scope:
raise errors.ModeratorOrScopeRequired(
function.__name__, scope)
raise errors.ModeratorRequired(function.__name__)
elif login:
if scope:
raise errors.LoginOrScopeRequired(function.__name__, scope)
raise errors.LoginRequired(function.__name__)
try:
return function(*args, **kwargs)
finally:
obj._use_oauth = False # pylint: disable=W0212
return wrap
@decorator.decorator
def require_oauth(function, *args, **kwargs):
"""Verify that the OAuth functions can be used prior to use.
Returned data is not modified.
"""
if not args[0].has_oauth_app_info:
err_msg = ("The OAuth app config parameters client_id, client_secret "
"and redirect_url must be specified to use this function.")
raise errors.OAuthAppRequired(err_msg)
return function(*args, **kwargs)
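
A minimal usage sketch for the `deprecated` factory above; the class and method names are made up, and the import path assumes the bundled package exposes `decorators` the same way upstream praw v3 does.

```python
import warnings

from rtv.packages import praw


class Legacy(object):  # hypothetical class
    @praw.decorators.deprecated('use shiny_new_method() instead')
    def old_method(self):
        return 42


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    print(Legacy().old_method())   # 42
    print(caught[0].category)      # <class 'DeprecationWarning'>
```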

rtv/packages/praw/errors.py (new file, +487 lines)

@@ -0,0 +1,487 @@
# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""
Error classes.
Includes two main exceptions: ClientException, when something goes
wrong on our end, and APIException, for when something goes wrong on the
server side. A number of classes extend these two main exceptions for more
specific exceptions.
"""
from __future__ import print_function, unicode_literals
import inspect
import six
import sys
class PRAWException(Exception):
"""The base PRAW Exception class.
Ideally, this can be caught to handle any exception from PRAW.
"""
class ClientException(PRAWException):
"""Base exception class for errors that don't involve the remote API."""
def __init__(self, message=None):
"""Construct a ClientException.
:param message: The error message to display.
"""
if not message:
message = 'Clientside error'
super(ClientException, self).__init__()
self.message = message
def __str__(self):
"""Return the message of the error."""
return self.message
class OAuthScopeRequired(ClientException):
"""Indicates that an OAuth2 scope is required to make the function call.
The attribute `scope` will contain the name of the necessary scope.
"""
def __init__(self, function, scope, message=None):
"""Contruct an OAuthScopeRequiredClientException.
:param function: The function that requires a scope.
:param scope: The scope required for the function.
:param message: A custom message to associate with the
exception. Default: `function` requires the OAuth2 scope `scope`
"""
if not message:
message = '`{0}` requires the OAuth2 scope `{1}`'.format(function,
scope)
super(OAuthScopeRequired, self).__init__(message)
self.scope = scope
class LoginRequired(ClientException):
"""Indicates that a logged in session is required.
This exception is raised on a preemptive basis, whereas NotLoggedIn occurs
in response to a lack of credentials on a privileged API call.
"""
def __init__(self, function, message=None):
"""Construct a LoginRequired exception.
:param function: The function that requires login-based authentication.
:param message: A custom message to associate with the exception.
Default: `function` requires a logged in session
"""
if not message:
message = '`{0}` requires a logged in session'.format(function)
super(LoginRequired, self).__init__(message)
class LoginOrScopeRequired(OAuthScopeRequired, LoginRequired):
"""Indicates that either a logged in session or OAuth2 scope is required.
The attribute `scope` will contain the name of the necessary scope.
"""
def __init__(self, function, scope, message=None):
"""Construct a LoginOrScopeRequired exception.
:param function: The function that requires authentication.
:param scope: The scope that is required if not logged in.
:param message: A custom message to associate with the exception.
Default: `function` requires a logged in session or the OAuth2
scope `scope`
"""
if not message:
message = ('`{0}` requires a logged in session or the '
'OAuth2 scope `{1}`').format(function, scope)
super(LoginOrScopeRequired, self).__init__(function, scope, message)
class ModeratorRequired(LoginRequired):
"""Indicates that a moderator of the subreddit is required."""
def __init__(self, function):
"""Construct a ModeratorRequired exception.
:param function: The function that requires moderator access.
"""
message = ('`{0}` requires a moderator '
'of the subreddit').format(function)
super(ModeratorRequired, self).__init__(message)
class ModeratorOrScopeRequired(LoginOrScopeRequired, ModeratorRequired):
"""Indicates that a moderator of the sub or OAuth2 scope is required.
The attribute `scope` will contain the name of the necessary scope.
"""
def __init__(self, function, scope):
"""Construct a ModeratorOrScopeRequired exception.
:param function: The function that requires moderator authentication or
a moderator scope.
:param scope: The scope that is required if not logged in with
moderator access.
"""
message = ('`{0}` requires a moderator of the subreddit or the '
'OAuth2 scope `{1}`').format(function, scope)
super(ModeratorOrScopeRequired, self).__init__(function, scope,
message)
class OAuthAppRequired(ClientException):
"""Raised when an OAuth client cannot be initialized.
This occurs when any one of the OAuth config values is not set.
"""
class HTTPException(PRAWException):
"""Base class for HTTP related exceptions."""
def __init__(self, _raw, message=None):
"""Construct a HTTPException.
:params _raw: The internal request library response object. This object
is mapped to attribute `_raw` whose format may change at any time.
"""
if not message:
message = 'HTTP error'
super(HTTPException, self).__init__()
self._raw = _raw
self.message = message
def __str__(self):
"""Return the message of the error."""
return self.message
class Forbidden(HTTPException):
"""Raised when the user does not have permission to the entity."""
class NotFound(HTTPException):
"""Raised when the requested entity is not found."""
class InvalidComment(PRAWException):
"""Indicate that the comment is no longer available on reddit."""
ERROR_TYPE = 'DELETED_COMMENT'
def __str__(self):
"""Return the message of the error."""
return self.ERROR_TYPE
class InvalidSubmission(PRAWException):
"""Indicates that the submission is no longer available on reddit."""
ERROR_TYPE = 'DELETED_LINK'
def __str__(self):
"""Return the message of the error."""
return self.ERROR_TYPE
class InvalidSubreddit(PRAWException):
"""Indicates that an invalid subreddit name was supplied."""
ERROR_TYPE = 'SUBREDDIT_NOEXIST'
def __str__(self):
"""Return the message of the error."""
return self.ERROR_TYPE
class RedirectException(PRAWException):
"""Raised when a redirect response occurs that is not expected."""
def __init__(self, request_url, response_url, message=None):
"""Construct a RedirectException.
:param request_url: The url requested.
:param response_url: The url being redirected to.
:param message: A custom message to associate with the exception.
"""
if not message:
message = ('Unexpected redirect '
'from {0} to {1}').format(request_url, response_url)
super(RedirectException, self).__init__()
self.request_url = request_url
self.response_url = response_url
self.message = message
def __str__(self):
"""Return the message of the error."""
return self.message
class OAuthException(PRAWException):
"""Base exception class for OAuth API calls.
Attribute `message` contains the error message.
Attribute `url` contains the url that resulted in the error.
"""
def __init__(self, message, url):
"""Construct a OAuthException.
:param message: The message associated with the exception.
:param url: The url that resulted in error.
"""
super(OAuthException, self).__init__()
self.message = message
self.url = url
def __str__(self):
"""Return the message along with the url."""
return self.message + " on url {0}".format(self.url)
class OAuthInsufficientScope(OAuthException):
"""Raised when the current OAuth scope is not sufficient for the action.
This indicates the access token is valid, but not for the desired action.
"""
class OAuthInvalidGrant(OAuthException):
"""Raised when the code to retrieve access information is not valid."""
class OAuthInvalidToken(OAuthException):
"""Raised when the current OAuth access token is not valid."""
class APIException(PRAWException):
"""Base exception class for the reddit API error message exceptions.
All exceptions of this type should have their own subclass.
"""
def __init__(self, error_type, message, field='', response=None):
"""Construct an APIException.
:param error_type: The error type set on reddit's end.
:param message: The associated message for the error.
:param field: The input field associated with the error, or ''.
:param response: The HTTP response that resulted in the exception.
"""
super(APIException, self).__init__()
self.error_type = error_type
self.message = message
self.field = field
self.response = response
def __str__(self):
"""Return a string containing the error message and field."""
if hasattr(self, 'ERROR_TYPE'):
return '`{0}` on field `{1}`'.format(self.message, self.field)
else:
return '({0}) `{1}` on field `{2}`'.format(self.error_type,
self.message,
self.field)
class ExceptionList(APIException):
"""Raised when more than one exception occurred."""
def __init__(self, errors):
"""Construct an ExceptionList.
:param errors: The list of errors.
"""
super(ExceptionList, self).__init__(None, None)
self.errors = errors
def __str__(self):
"""Return a string representation for all the errors."""
ret = '\n'
for i, error in enumerate(self.errors):
ret += '\tError {0}) {1}\n'.format(i, six.text_type(error))
return ret
class AlreadySubmitted(APIException):
"""An exception to indicate that a URL was previously submitted."""
ERROR_TYPE = 'ALREADY_SUB'
class AlreadyModerator(APIException):
"""Used to indicate that a user is already a moderator of a subreddit."""
ERROR_TYPE = 'ALREADY_MODERATOR'
class BadCSS(APIException):
"""An exception to indicate bad CSS (such as invalid) was used."""
ERROR_TYPE = 'BAD_CSS'
class BadCSSName(APIException):
"""An exception to indicate a bad CSS name (such as invalid) was used."""
ERROR_TYPE = 'BAD_CSS_NAME'
class BadUsername(APIException):
"""An exception to indicate an invalid username was used."""
ERROR_TYPE = 'BAD_USERNAME'
class InvalidCaptcha(APIException):
"""An exception for when an incorrect captcha error is returned."""
ERROR_TYPE = 'BAD_CAPTCHA'
class InvalidEmails(APIException):
"""An exception for when invalid emails are provided."""
ERROR_TYPE = 'BAD_EMAILS'
class InvalidFlairTarget(APIException):
"""An exception raised when an invalid user is passed as a flair target."""
ERROR_TYPE = 'BAD_FLAIR_TARGET'
class InvalidInvite(APIException):
"""Raised when attempting to accept a nonexistent moderator invite."""
ERROR_TYPE = 'NO_INVITE_FOUND'
class InvalidUser(APIException):
"""An exception for when a user doesn't exist."""
ERROR_TYPE = 'USER_DOESNT_EXIST'
class InvalidUserPass(APIException):
"""An exception for failed logins."""
ERROR_TYPE = 'WRONG_PASSWORD'
class InsufficientCreddits(APIException):
"""Raised when there are not enough creddits to complete the action."""
ERROR_TYPE = 'INSUFFICIENT_CREDDITS'
class NotLoggedIn(APIException):
"""An exception for when a Reddit user isn't logged in."""
ERROR_TYPE = 'USER_REQUIRED'
class NotModified(APIException):
"""An exception raised when reddit returns {'error': 304}.
This error indicates that the requested content was not modified and is
being requested too frequently. Such an error usually occurs when multiple
instances of PRAW are running concurrently or in rapid succession.
"""
def __init__(self, response):
"""Construct an instance of the NotModified exception.
This error does not have an error_type, message, nor field.
"""
super(NotModified, self).__init__(None, None, response=response)
def __str__(self):
"""Return: That page has not been modified."""
return 'That page has not been modified.'
class RateLimitExceeded(APIException):
"""An exception for when something has happened too frequently.
Contains a `sleep_time` attribute for the number of seconds that must
transpire prior to the next request.
"""
ERROR_TYPE = 'RATELIMIT'
def __init__(self, error_type, message, field, response):
"""Construct an instance of the RateLimitExceeded exception.
The parameters match that of :class:`APIException`.
The `sleep_time` attribute is extracted from the response object.
"""
super(RateLimitExceeded, self).__init__(error_type, message,
field, response)
self.sleep_time = self.response['ratelimit']
class SubredditExists(APIException):
"""An exception to indicate that a subreddit name is not available."""
ERROR_TYPE = 'SUBREDDIT_EXISTS'
class UsernameExists(APIException):
"""An exception to indicate that a username is not available."""
ERROR_TYPE = 'USERNAME_TAKEN'
def _build_error_mapping():
def predicate(obj):
return inspect.isclass(obj) and hasattr(obj, 'ERROR_TYPE')
tmp = {}
for _, obj in inspect.getmembers(sys.modules[__name__], predicate):
tmp[obj.ERROR_TYPE] = obj
return tmp
ERROR_MAPPING = _build_error_mapping()
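
A small sketch of how `ERROR_MAPPING` is consumed (this mirrors the lookup `raise_api_exceptions` performs in decorators.py; the error payload below is fabricated, and the import path assumes the bundled module location from this diff).

```python
from rtv.packages.praw import errors

# Map a reddit error identifier onto its exception class, falling back to
# the generic APIException for unknown identifiers.
error_type, msg, field = 'BAD_USERNAME', 'invalid user name', 'name'
error_class = errors.ERROR_MAPPING.get(error_type, errors.APIException)

exc = error_class(error_type, msg, field)
print(isinstance(exc, errors.BadUsername))  # True
print(exc)  # `invalid user name` on field `name`
```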


@@ -0,0 +1,243 @@
"""Provides classes that handle request dispatching."""
from __future__ import print_function, unicode_literals
import socket
import sys
import time
from functools import wraps
from praw.errors import ClientException
from praw.helpers import normalize_url
from requests import Session
from six import text_type
from six.moves import cPickle # pylint: disable=F0401
from threading import Lock
from timeit import default_timer as timer
class RateLimitHandler(object):
"""The base handler that provides thread-safe rate limiting enforcement.
While this handler is thread-safe, PRAW is not thread-safe when the same
`Reddit` instance is used from multiple threads.
"""
last_call = {} # Stores a two-item list: [lock, previous_call_time]
rl_lock = Lock() # lock used for adding items to last_call
@staticmethod
def rate_limit(function):
"""Return a decorator that enforces API request limit guidelines.
We are allowed to make an API request every api_request_delay seconds as
specified in praw.ini. This value may differ from reddit to reddit. For
reddit.com it is 2. Any function decorated with this will be forced to
delay _rate_delay seconds from the calling of the last function
decorated with this before executing.
This decorator must be applied to a RateLimitHandler class method or
instance method as it assumes `rl_lock` and `last_call` are available.
"""
@wraps(function)
def wrapped(cls, _rate_domain, _rate_delay, **kwargs):
cls.rl_lock.acquire()
lock_last = cls.last_call.setdefault(_rate_domain, [Lock(), 0])
with lock_last[0]: # Obtain the domain specific lock
cls.rl_lock.release()
# Sleep if necessary, then perform the request
now = timer()
delay = lock_last[1] + _rate_delay - now
if delay > 0:
now += delay
time.sleep(delay)
lock_last[1] = now
return function(cls, **kwargs)
return wrapped
@classmethod
def evict(cls, urls): # pylint: disable=W0613
"""Method utilized to evict entries for the given urls.
:param urls: An iterable containing normalized urls.
:returns: The number of items removed from the cache.
By default this method returns 0, as a cache need not be present.
"""
return 0
def __del__(self):
"""Cleanup the HTTP session."""
if self.http:
try:
self.http.close()
except: # Never fail pylint: disable=W0702
pass
def __init__(self):
"""Establish the HTTP session."""
self.http = Session() # Each instance should have its own session
def request(self, request, proxies, timeout, verify, **_):
"""Responsible for dispatching the request and returning the result.
Network level exceptions should be raised and only
``requests.Response`` should be returned.
:param request: A ``requests.PreparedRequest`` object containing all
the data necessary to perform the request.
:param proxies: A dictionary of proxy settings to be utilized for the
request.
:param timeout: Specifies the maximum time that the actual HTTP request
can take.
:param verify: Specifies if SSL certificates should be validated.
``**_`` should be added to the method call to ignore the extra
arguments intended for the cache handler.
"""
settings = self.http.merge_environment_settings(
request.url, proxies, False, verify, None
)
return self.http.send(request, timeout=timeout, allow_redirects=False,
**settings)
RateLimitHandler.request = RateLimitHandler.rate_limit(
RateLimitHandler.request)
class DefaultHandler(RateLimitHandler):
"""Extends the RateLimitHandler to add thread-safe caching support."""
ca_lock = Lock()
cache = {}
cache_hit_callback = None
timeouts = {}
@staticmethod
def with_cache(function):
"""Return a decorator that interacts with a handler's cache.
This decorator must be applied to a DefaultHandler class method or
instance method as it assumes `cache`, `ca_lock` and `timeouts` are
available.
"""
@wraps(function)
def wrapped(cls, _cache_key, _cache_ignore, _cache_timeout, **kwargs):
def clear_timeouts():
"""Clear the cache of timed out results."""
for key in list(cls.timeouts):
if timer() - cls.timeouts[key] > _cache_timeout:
del cls.timeouts[key]
del cls.cache[key]
if _cache_ignore:
return function(cls, **kwargs)
with cls.ca_lock:
clear_timeouts()
if _cache_key in cls.cache:
if cls.cache_hit_callback:
cls.cache_hit_callback(_cache_key)
return cls.cache[_cache_key]
# Releasing the lock before actually making the request allows for
# the possibility of more than one thread making the same request
# to get through. Without having domain-specific caching (under the
# assumption only one request to a domain can be made at a
# time), there isn't a better way to handle this.
result = function(cls, **kwargs)
# The handlers don't call `raise_for_status` so we need to ignore
# status codes that will result in an exception that should not be
# cached.
if result.status_code not in (200, 302):
return result
with cls.ca_lock:
cls.timeouts[_cache_key] = timer()
cls.cache[_cache_key] = result
return result
return wrapped
@classmethod
def clear_cache(cls):
"""Remove all items from the cache."""
with cls.ca_lock:
cls.cache = {}
cls.timeouts = {}
@classmethod
def evict(cls, urls):
"""Remove items from cache matching URLs.
Return the number of items removed.
"""
if isinstance(urls, text_type):
urls = [urls]
urls = set(normalize_url(url) for url in urls)
retval = 0
with cls.ca_lock:
for key in list(cls.cache):
if key[0] in urls:
retval += 1
del cls.cache[key]
del cls.timeouts[key]
return retval
DefaultHandler.request = DefaultHandler.with_cache(RateLimitHandler.request)
class MultiprocessHandler(object):
"""A PRAW handler to interact with the PRAW multi-process server."""
def __init__(self, host='localhost', port=10101):
"""Construct an instance of the MultiprocessHandler."""
self.host = host
self.port = port
def _relay(self, **kwargs):
"""Send the request through the server and return the HTTP response."""
retval = None
delay_time = 2 # For connection retries
read_attempts = 0 # For reading from socket
while retval is None: # Evict can return False
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock_fp = sock.makefile('rwb') # Used for pickle
try:
sock.connect((self.host, self.port))
cPickle.dump(kwargs, sock_fp, cPickle.HIGHEST_PROTOCOL)
sock_fp.flush()
retval = cPickle.load(sock_fp)
except: # pylint: disable=W0702
exc_type, exc, _ = sys.exc_info()
socket_error = exc_type is socket.error
if socket_error and exc.errno == 111: # Connection refused
sys.stderr.write('Cannot connect to multiprocess server. I'
's it running? Retrying in {0} seconds.\n'
.format(delay_time))
time.sleep(delay_time)
delay_time = min(64, delay_time * 2)
elif exc_type is EOFError or socket_error and exc.errno == 104:
# Failure during socket READ
if read_attempts >= 3:
raise ClientException('Successive failures reading '
'from the multiprocess server.')
sys.stderr.write('Lost connection with multiprocess server'
' during read. Trying again.\n')
read_attempts += 1
else:
raise
finally:
sock_fp.close()
sock.close()
if isinstance(retval, Exception):
raise retval # pylint: disable=E0702
return retval
def evict(self, urls):
"""Forward the eviction to the server and return its response."""
return self._relay(method='evict', urls=urls)
def request(self, **kwargs):
"""Forward the request to the server and return its HTTP response."""
return self._relay(method='request', **kwargs)
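
A hedged sketch of picking a handler for a session. In upstream praw v3 the `Reddit` constructor accepts a `handler` keyword (and uses `DefaultHandler` when none is given); the same is assumed to hold for the bundled copy.

```python
from rtv.packages import praw

# Explicit cached + rate-limited handler (equivalent to the default).
r = praw.Reddit(user_agent='rtv handler sketch',
                handler=praw.handlers.DefaultHandler())

# Or route every request through a praw-multiprocess server so several
# processes share one cache and one rate limit:
# r = praw.Reddit(user_agent='rtv handler sketch',
#                 handler=praw.handlers.MultiprocessHandler('localhost', 10101))
```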


@@ -0,0 +1,481 @@
# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""
Helper functions.
The functions here provide functionality that is often needed by programs using
PRAW, but which isn't part of reddit's API.
"""
from __future__ import unicode_literals
import six
import sys
import time
from collections import deque
from functools import partial
from timeit import default_timer as timer
from praw.errors import HTTPException, PRAWException
from operator import attrgetter
BACKOFF_START = 4 # Minimum number of seconds to sleep during errors
KEEP_ITEMS = 128 # On each iteration only remember the first KEEP_ITEMS items
# for conversion between broken reddit timestamps and unix timestamps
REDDIT_TIMESTAMP_OFFSET = 28800
def comment_stream(reddit_session, subreddit, limit=None, verbosity=1):
"""Indefinitely yield new comments from the provided subreddit.
Comments are yielded from oldest to newest.
:param reddit_session: The reddit_session to make requests from. In all the
examples this is assigned to the variable ``r``.
:param subreddit: Either a subreddit object, or the name of a
subreddit. Use `all` to get the comment stream for all comments made to
reddit.
:param limit: The maximum number of comments to fetch in a single
iteration. When None, fetch all available comments (reddit limits this
to 1000, or a multiple of 1000 for multi-subreddits). If this number is
too small, comments may be missed.
:param verbosity: A number that controls the amount of output produced to
stderr. <= 0: no output; >= 1: output the total number of comments
processed and provide the short-term number of comments processed per
second; >= 2: output when additional delays are added in order to avoid
subsequent unexpected http errors. >= 3: output debugging information
regarding the comment stream. (Default: 1)
"""
get_function = partial(reddit_session.get_comments,
six.text_type(subreddit))
return _stream_generator(get_function, limit, verbosity)
def submission_stream(reddit_session, subreddit, limit=None, verbosity=1):
"""Indefinitely yield new submissions from the provided subreddit.
Submissions are yielded from oldest to newest.
:param reddit_session: The reddit_session to make requests from. In all the
examples this is assigned to the variable ``r``.
:param subreddit: Either a subreddit object, or the name of a
subreddit. Use `all` to get the submissions stream for all submissions
made to reddit.
:param limit: The maximum number of submissions to fetch in a single
iteration. When None, fetch all available submissions (reddit limits
this to 1000, or a multiple of 1000 for multi-subreddits). If this number
is too small, submissions may be missed. Since there isn't a limit to
the number of submissions that can be retrieved from r/all, the limit
will be set to 1000 when limit is None.
:param verbosity: A number that controls the amount of output produced to
stderr. <= 0: no output; >= 1: output the total number of submissions
processed and provide the short-term number of submissions processed
per second; >= 2: output when additional delays are added in order to
avoid subsequent unexpected http errors. >= 3: output debugging
information regarding the submission stream. (Default: 1)
"""
if six.text_type(subreddit).lower() == "all":
if limit is None:
limit = 1000
if not hasattr(subreddit, 'reddit_session'):
subreddit = reddit_session.get_subreddit(subreddit)
return _stream_generator(subreddit.get_new, limit, verbosity)
def valid_redditors(redditors, sub):
"""Return a verified list of valid Redditor instances.
:param redditors: A list comprised of Redditor instances and/or strings
that are to be verified as actual redditor accounts.
:param sub: A Subreddit instance that the authenticated account has
flair changing permission on.
Note: Flair will be unset for all valid redditors in `redditors` on the
subreddit `sub`. A valid redditor is defined as a redditor that is
registered on reddit.
"""
simplified = list(set(six.text_type(x).lower() for x in redditors))
return [sub.reddit_session.get_redditor(simplified[i], fetch=False)
for (i, resp) in enumerate(sub.set_flair_csv(
({'user': x, 'flair_text': x} for x in simplified)))
if resp['ok']]
def submissions_between(reddit_session,
subreddit,
lowest_timestamp=None,
highest_timestamp=None,
newest_first=True,
extra_cloudsearch_fields=None,
verbosity=1):
"""Yield submissions between two timestamps.
If both ``highest_timestamp`` and ``lowest_timestamp`` are unspecified,
yields all submissions in the ``subreddit``.
Submissions are yielded from newest to oldest (like in the "new" queue).
:param reddit_session: The reddit_session to make requests from. In all the
examples this is assigned to the variable ``r``.
:param subreddit: Either a subreddit object, or the name of a
subreddit. Use `all` to get the submissions stream for all submissions
made to reddit.
:param lowest_timestamp: The lower bound for ``created_utc`` attribute of
submissions.
(Default: subreddit's created_utc or 0 when subreddit == "all").
:param highest_timestamp: The upper bound for ``created_utc`` attribute
of submissions. (Default: current unix time)
NOTE: both highest_timestamp and lowest_timestamp are proper
unix timestamps (just like ``created_utc`` attributes)
:param newest_first: If set to true, yields submissions
from newest to oldest. Otherwise yields submissions
from oldest to newest
:param extra_cloudsearch_fields: Allows extra filtering of results by
parameters like author, self. Full list is available here:
https://www.reddit.com/wiki/search
:param verbosity: A number that controls the amount of output produced to
stderr. <= 0: no output; >= 1: output the total number of submissions
processed; >= 2: output debugging information regarding
the search queries. (Default: 1)
"""
def debug(msg, level):
if verbosity >= level:
sys.stderr.write(msg + '\n')
def format_query_field(k, v):
if k in ["nsfw", "self"]:
# even though documentation lists "no" and "yes"
# as possible values, in reality they don't work
if v not in [0, 1, "0", "1"]:
raise PRAWException("Invalid value for the extra"
"field {}. Only '0' and '1' are"
"valid values.".format(k))
return "{}:{}".format(k, v)
return "{}:'{}'".format(k, v)
if extra_cloudsearch_fields is None:
extra_cloudsearch_fields = {}
extra_query_part = " ".join(
[format_query_field(k, v) for (k, v)
in sorted(extra_cloudsearch_fields.items())]
)
if highest_timestamp is None:
highest_timestamp = int(time.time()) + REDDIT_TIMESTAMP_OFFSET
else:
highest_timestamp = int(highest_timestamp) + REDDIT_TIMESTAMP_OFFSET
if lowest_timestamp is not None:
lowest_timestamp = int(lowest_timestamp) + REDDIT_TIMESTAMP_OFFSET
elif not isinstance(subreddit, six.string_types):
lowest_timestamp = int(subreddit.created)
elif subreddit not in ("all", "contrib", "mod", "friend"):
lowest_timestamp = int(reddit_session.get_subreddit(subreddit).created)
else:
lowest_timestamp = 0
original_highest_timestamp = highest_timestamp
original_lowest_timestamp = lowest_timestamp
# When making timestamp:X..Y queries, reddit misses submissions
# inside X..Y range, but they can be found inside Y..Z range
# It is not clear what the value of Z should be, but it seems
# like the difference is usually about ~1 hour or less
# To be sure, let's set the workaround offset to 2 hours
out_of_order_submissions_workaround_offset = 7200
highest_timestamp += out_of_order_submissions_workaround_offset
lowest_timestamp -= out_of_order_submissions_workaround_offset
# Those parameters work ok, but there may be a better set of parameters
window_size = 60 * 60
search_limit = 100
min_search_results_in_window = 50
window_adjustment_ratio = 1.25
backoff = BACKOFF_START
processed_submissions = 0
prev_win_increased = False
prev_win_decreased = False
while highest_timestamp >= lowest_timestamp:
try:
if newest_first:
t1 = max(highest_timestamp - window_size, lowest_timestamp)
t2 = highest_timestamp
else:
t1 = lowest_timestamp
t2 = min(lowest_timestamp + window_size, highest_timestamp)
search_query = 'timestamp:{}..{}'.format(t1, t2)
if extra_query_part:
search_query = "(and {} {})".format(search_query,
extra_query_part)
debug(search_query, 3)
search_results = list(reddit_session.search(search_query,
subreddit=subreddit,
limit=search_limit,
syntax='cloudsearch',
sort='new'))
debug("Received {0} search results for query {1}"
.format(len(search_results), search_query),
2)
backoff = BACKOFF_START
except HTTPException as exc:
debug("{0}. Sleeping for {1} seconds".format(exc, backoff), 2)
time.sleep(backoff)
backoff *= 2
continue
if len(search_results) >= search_limit:
power = 2 if prev_win_decreased else 1
window_size = int(window_size / window_adjustment_ratio**power)
prev_win_decreased = True
debug("Decreasing window size to {0} seconds".format(window_size),
2)
# Since it is possible that there are more submissions
# in the current window, we have to re-do the request
# with a reduced window
continue
else:
prev_win_decreased = False
search_results = [s for s in search_results
if original_lowest_timestamp <= s.created and
s.created <= original_highest_timestamp]
for submission in sorted(search_results,
key=attrgetter('created_utc', 'id'),
reverse=newest_first):
yield submission
processed_submissions += len(search_results)
debug('Total processed submissions: {}'
.format(processed_submissions), 1)
if newest_first:
highest_timestamp -= (window_size + 1)
else:
lowest_timestamp += (window_size + 1)
if len(search_results) < min_search_results_in_window:
power = 2 if prev_win_increased else 1
window_size = int(window_size * window_adjustment_ratio**power)
prev_win_increased = True
debug("Increasing window size to {0} seconds"
.format(window_size), 2)
else:
prev_win_increased = False
def _stream_generator(get_function, limit=None, verbosity=1):
def debug(msg, level):
if verbosity >= level:
sys.stderr.write(msg + '\n')
def b36_id(item):
return int(item.id, 36)
seen = BoundedSet(KEEP_ITEMS * 16)
before = None
count = 0 # Count is incremented to bypass the cache
processed = 0
backoff = BACKOFF_START
while True:
items = []
sleep = None
start = timer()
try:
i = None
params = {'uniq': count}
count = (count + 1) % 100
if before:
params['before'] = before
gen = enumerate(get_function(limit=limit, params=params))
for i, item in gen:
if b36_id(item) in seen:
if i == 0:
if before is not None:
# reddit sent us out of order data -- log it
debug('(INFO) {0} already seen with before of {1}'
.format(item.fullname, before), 3)
before = None
break
if i == 0: # Always the first item in the generator
before = item.fullname
if b36_id(item) not in seen:
items.append(item)
processed += 1
if verbosity >= 1 and processed % 100 == 0:
sys.stderr.write(' Items: {0} \r'
.format(processed))
sys.stderr.flush()
if i < KEEP_ITEMS:
seen.add(b36_id(item))
else: # Generator exhausted
if i is None: # Generator yielded no items
assert before is not None
# Try again without before as the before item may be too
# old or no longer exist.
before = None
backoff = BACKOFF_START
except HTTPException as exc:
sleep = (backoff, '{0}. Sleeping for {{0}} seconds.'.format(exc),
2)
backoff *= 2
# Report the processing rate
if verbosity >= 1:
rate = len(items) / (timer() - start)
sys.stderr.write(' Items: {0} ({1:.2f} ips) \r'
.format(processed, rate))
sys.stderr.flush()
# Yield items from oldest to newest
for item in items[::-1]:
yield item
# Sleep if necessary
if sleep:
sleep_time, msg, msg_level = sleep # pylint: disable=W0633
debug(msg.format(sleep_time), msg_level)
time.sleep(sleep_time)
def chunk_sequence(sequence, chunk_length, allow_incomplete=True):
"""Given a sequence, divide it into sequences of length `chunk_length`.
:param allow_incomplete: If True, allow final chunk to be shorter if the
given sequence is not an exact multiple of `chunk_length`.
If False, the incomplete chunk will be discarded.
"""
(complete, leftover) = divmod(len(sequence), chunk_length)
if not allow_incomplete:
leftover = 0
chunk_count = complete + min(leftover, 1)
chunks = []
for x in range(chunk_count):
left = chunk_length * x
right = left + chunk_length
chunks.append(sequence[left:right])
return chunks
def convert_id36_to_numeric_id(id36):
"""Convert strings representing base36 numbers into an integer."""
if not isinstance(id36, six.string_types) or id36.count("_") > 0:
raise ValueError("must supply base36 string, not fullname (e.g. use "
"xxxxx, not t3_xxxxx)")
return int(id36, 36)
def convert_numeric_id_to_id36(numeric_id):
"""Convert an integer into its base36 string representation.
This method has been cleaned up slightly to improve readability. For more
info see:
https://github.com/reddit/reddit/blob/master/r2/r2/lib/utils/_utils.pyx
https://www.reddit.com/r/redditdev/comments/n624n/submission_ids_question/
https://en.wikipedia.org/wiki/Base36
"""
# base36 allows negative numbers, but reddit does not
if not isinstance(numeric_id, six.integer_types) or numeric_id < 0:
raise ValueError("must supply a positive int/long")
# Alphabet used for base 36 conversion
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
alphabet_len = len(alphabet)
# Temp assign
current_number = numeric_id
base36 = []
# current_number must be at least alphabet_len to need the while/divmod loop
if 0 <= current_number < alphabet_len:
return alphabet[current_number]
# Break up into chunks
while current_number != 0:
current_number, rem = divmod(current_number, alphabet_len)
base36.append(alphabet[rem])
# String is built in reverse order
return ''.join(reversed(base36))
def flatten_tree(tree, nested_attr='replies', depth_first=False):
"""Return a flattened version of the passed in tree.
:param nested_attr: The attribute name that contains the nested items.
Defaults to ``replies`` which is suitable for comments.
:param depth_first: When true, add to the list in a depth-first manner
rather than the default breadth-first manner.
"""
stack = deque(tree)
extend = stack.extend if depth_first else stack.extendleft
retval = []
while stack:
item = stack.popleft()
nested = getattr(item, nested_attr, None)
if nested:
extend(nested)
retval.append(item)
return retval
def normalize_url(url):
"""Return url after stripping trailing .json and trailing slashes."""
if url.endswith('.json'):
url = url[:-5]
if url.endswith('/'):
url = url[:-1]
return url
class BoundedSet(object):
"""A set with a maximum size that evicts the oldest items when necessary.
This class does not implement the complete set interface.
"""
def __init__(self, max_items):
"""Construct an instance of the BoundedSet."""
self.max_items = max_items
self._fifo = []
self._set = set()
def __contains__(self, item):
"""Test if the BoundedSet contains item."""
return item in self._set
def add(self, item):
"""Add an item to the set discarding the oldest item if necessary."""
if item in self._set:
self._fifo.remove(item)
elif len(self._set) == self.max_items:
self._set.remove(self._fifo.pop(0))
self._fifo.append(item)
self._set.add(item)
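
A self-contained sketch of the eviction behaviour of `BoundedSet`, the structure `_stream_generator` uses to remember recently seen item ids (the import path assumes the bundled module location from this diff).

```python
from rtv.packages.praw.helpers import BoundedSet

seen = BoundedSet(max_items=2)
seen.add('a')
seen.add('b')
seen.add('a')   # re-adding refreshes 'a', so 'b' becomes the oldest entry
seen.add('c')   # at capacity: the oldest entry ('b') is evicted

print('a' in seen, 'b' in seen, 'c' in seen)  # True False True
```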


@@ -0,0 +1,271 @@
# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""Internal helper functions.
The functions in this module are not to be relied upon by third-parties.
"""
from __future__ import print_function, unicode_literals
import os
import re
import six
import sys
from requests import Request, codes, exceptions
from requests.compat import urljoin
from praw.decorators import restrict_access
from praw.errors import (ClientException, HTTPException, Forbidden, NotFound,
InvalidSubreddit, OAuthException,
OAuthInsufficientScope, OAuthInvalidToken,
RedirectException)
from warnings import warn
try:
from OpenSSL import __version__ as _opensslversion
_opensslversionlist = [int(minor) if minor.isdigit() else minor
for minor in _opensslversion.split('.')]
except ImportError:
_opensslversionlist = [0, 15]
MIN_PNG_SIZE = 67
MIN_JPEG_SIZE = 128
MAX_IMAGE_SIZE = 512000
JPEG_HEADER = b'\xff\xd8\xff'
PNG_HEADER = b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
RE_REDIRECT = re.compile('(rand(om|nsfw))|about/sticky')
def _get_redditor_listing(subpath=''):
"""Return function to generate Redditor listings."""
def _listing(self, sort='new', time='all', *args, **kwargs):
"""Return a get_content generator for some RedditContentObject type.
:param sort: Specify the sort order of the results if applicable
(one of ``'hot'``, ``'new'``, ``'top'``, ``'controversial'``).
:param time: Specify the time-period to return submissions if
applicable (one of ``'hour'``, ``'day'``, ``'week'``,
``'month'``, ``'year'``, ``'all'``).
The additional parameters are passed directly into
:meth:`.get_content`. Note: the `url` parameter cannot be altered.
"""
kwargs.setdefault('params', {})
kwargs['params'].setdefault('sort', sort)
kwargs['params'].setdefault('t', time)
url = urljoin(self._url, subpath) # pylint: disable=W0212
return self.reddit_session.get_content(url, *args, **kwargs)
return _listing
def _get_sorter(subpath='', **defaults):
"""Return function to generate specific subreddit Submission listings."""
@restrict_access(scope='read')
def _sorted(self, *args, **kwargs):
"""Return a get_content generator for some RedditContentObject type.
The additional parameters are passed directly into
:meth:`.get_content`. Note: the `url` parameter cannot be altered.
"""
if not kwargs.get('params'):
kwargs['params'] = {}
for key, value in six.iteritems(defaults):
kwargs['params'].setdefault(key, value)
url = urljoin(self._url, subpath) # pylint: disable=W0212
return self.reddit_session.get_content(url, *args, **kwargs)
return _sorted
def _image_type(image):
size = os.path.getsize(image.name)
if size < MIN_PNG_SIZE:
raise ClientException('png image is too small.')
if size > MAX_IMAGE_SIZE:
raise ClientException('`image` is too big. Max: {0} bytes'
.format(MAX_IMAGE_SIZE))
first_bytes = image.read(MIN_PNG_SIZE)
image.seek(0)
if first_bytes.startswith(PNG_HEADER):
return 'png'
elif first_bytes.startswith(JPEG_HEADER):
if size < MIN_JPEG_SIZE:
raise ClientException('jpeg image is too small.')
return 'jpg'
raise ClientException('`image` must be either jpg or png.')
def _modify_relationship(relationship, unlink=False, is_sub=False):
"""Return a function for relationship modification.
Used to support friending (user-to-user), as well as moderating,
contributor creating, and banning (user-to-subreddit).
"""
# The API uses friend and unfriend to manage all of these relationships.
url_key = 'unfriend' if unlink else 'friend'
if relationship == 'friend':
access = {'scope': None, 'login': True}
elif relationship == 'moderator':
access = {'scope': 'modothers'}
elif relationship in ['banned', 'contributor', 'muted']:
access = {'scope': 'modcontributors'}
elif relationship in ['wikibanned', 'wikicontributor']:
access = {'scope': ['modcontributors', 'modwiki']}
else:
access = {'scope': None, 'mod': True}
@restrict_access(**access)
def do_relationship(thing, user, **kwargs):
data = {'name': six.text_type(user),
'type': relationship}
data.update(kwargs)
if is_sub:
data['r'] = six.text_type(thing)
else:
data['container'] = thing.fullname
session = thing.reddit_session
if relationship == 'moderator':
session.evict(session.config['moderators'].format(
subreddit=six.text_type(thing)))
url = session.config[url_key]
return session.request_json(url, data=data)
return do_relationship
def _prepare_request(reddit_session, url, params, data, auth, files,
method=None):
"""Return a requests Request object that can be "prepared"."""
# Requests using OAuth for authorization must switch to using the oauth
# domain.
if getattr(reddit_session, '_use_oauth', False):
bearer = 'bearer {0}'.format(reddit_session.access_token)
headers = {'Authorization': bearer}
config = reddit_session.config
for prefix in (config.api_url, config.permalink_url):
if url.startswith(prefix):
if config.log_requests >= 1:
msg = 'substituting {0} for {1} in url\n'.format(
config.oauth_url, prefix)
sys.stderr.write(msg)
url = config.oauth_url + url[len(prefix):]
break
else:
headers = {}
headers.update(reddit_session.http.headers)
if method:
pass
elif data or files:
method = 'POST'
else:
method = 'GET'
# Log the request if logging is enabled
if reddit_session.config.log_requests >= 1:
sys.stderr.write('{0}: {1}\n'.format(method, url))
if reddit_session.config.log_requests >= 2:
if params:
sys.stderr.write('params: {0}\n'.format(params))
if data:
sys.stderr.write('data: {0}\n'.format(data))
if auth:
sys.stderr.write('auth: {0}\n'.format(auth))
# Prepare request
request = Request(method=method, url=url, headers=headers, params=params,
auth=auth, cookies=reddit_session.http.cookies)
if method == 'GET':
return request
# Most POST requests require adding `api_type` and `uh` to the data.
if data is True:
data = {}
if isinstance(data, dict):
if not auth:
data.setdefault('api_type', 'json')
if reddit_session.modhash:
data.setdefault('uh', reddit_session.modhash)
else:
request.headers.setdefault('Content-Type', 'application/json')
request.data = data
request.files = files
return request
def _raise_redirect_exceptions(response):
"""Return the new url or None if there are no redirects.
Raise exceptions if appropriate.
"""
if response.status_code not in [301, 302, 307]:
return None
new_url = urljoin(response.url, response.headers['location'])
if 'reddits/search' in new_url: # Handle non-existent subreddit
subreddit = new_url.rsplit('=', 1)[1]
raise InvalidSubreddit('`{0}` is not a valid subreddit'
.format(subreddit))
elif not RE_REDIRECT.search(response.url):
raise RedirectException(response.url, new_url)
return new_url
def _raise_response_exceptions(response):
"""Raise specific errors on some status codes."""
if not response.ok and 'www-authenticate' in response.headers:
msg = response.headers['www-authenticate']
if 'insufficient_scope' in msg:
raise OAuthInsufficientScope('insufficient_scope', response.url)
elif 'invalid_token' in msg:
raise OAuthInvalidToken('invalid_token', response.url)
else:
raise OAuthException(msg, response.url)
if response.status_code == codes.forbidden: # pylint: disable=E1101
raise Forbidden(_raw=response)
elif response.status_code == codes.not_found: # pylint: disable=E1101
raise NotFound(_raw=response)
else:
try:
response.raise_for_status() # These should all be directly mapped
except exceptions.HTTPError as exc:
raise HTTPException(_raw=exc.response)
def _to_reddit_list(arg):
"""Return an argument converted to a reddit-formatted list.
The returned format is a comma-delimited list. Each element is the string
representation of an object, either given as a string or as an object that
is then converted to its string representation.
"""
if (isinstance(arg, six.string_types) or not (
hasattr(arg, "__getitem__") or hasattr(arg, "__iter__"))):
return six.text_type(arg)
else:
return ','.join(six.text_type(a) for a in arg)
def _warn_pyopenssl():
"""Warn the user against faulty versions of pyOpenSSL."""
if _opensslversionlist < [0, 15]: # versions >= 0.15 are fine
warn(RuntimeWarning(
"pyOpenSSL {0} may be incompatible with praw if validating"
"ssl certificates, which is on by default.\nSee https://"
"github.com/praw/pull/625 for more information".format(
_opensslversion)
))
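
For illustration only, a minimal usage sketch of two of the helpers above; it is not part of the bundled file, and the import paths and request target are assumptions based on this commit's layout:

    import requests
    from rtv.packages.praw.errors import Forbidden, HTTPException, NotFound
    from rtv.packages.praw.internal import (_raise_response_exceptions,
                                            _to_reddit_list)

    # Strings pass through unchanged; iterables become a comma-delimited string.
    print(_to_reddit_list('python'))                   # -> 'python'
    print(_to_reddit_list(['python', 'learnpython']))  # -> 'python,learnpython'

    # Map an HTTP response onto the PRAW exception hierarchy (no-op on success).
    resp = requests.get('https://www.reddit.com/r/python/about.json',
                        headers={'User-Agent': 'rtv example'})
    try:
        _raise_response_exceptions(resp)
    except (Forbidden, NotFound, HTTPException) as exc:
        print('Request failed:', exc)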

View File

@@ -0,0 +1,102 @@
"""Provides a request server to be used with the multiprocess handler."""
from __future__ import print_function, unicode_literals
import socket
import sys
from optparse import OptionParser
from praw import __version__
from praw.handlers import DefaultHandler
from requests import Session
from six.moves import cPickle, socketserver # pylint: disable=F0401
from threading import Lock
class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
# pylint: disable=R0903,W0232
"""A TCP server that creates new threads per connection."""
allow_reuse_address = True
@staticmethod
def handle_error(_, client_addr):
"""Mute tracebacks of common errors."""
exc_type, exc_value, _ = sys.exc_info()
if exc_type is socket.error and exc_value[0] == 32:
pass
elif exc_type is cPickle.UnpicklingError:
sys.stderr.write('Invalid connection from {0}\n'
.format(client_addr[0]))
else:
raise
class RequestHandler(socketserver.StreamRequestHandler):
# pylint: disable=W0232
"""A class that handles incoming requests.
Requests to the same domain are cached and rate-limited.
"""
ca_lock = Lock() # lock around cache and timeouts
cache = {} # caches requests
http = Session() # used to make requests
last_call = {} # Stores a two-item list: [lock, previous_call_time]
rl_lock = Lock() # lock used for adding items to last_call
timeouts = {} # store the time items in cache were entered
do_evict = DefaultHandler.evict # Add in the evict method
@staticmethod
def cache_hit_callback(key):
"""Output when a cache hit occurs."""
print('HIT {0} {1}'.format('POST' if key[1][1] else 'GET', key[0]))
@DefaultHandler.with_cache
@DefaultHandler.rate_limit
def do_request(self, request, proxies, timeout, **_):
"""Dispatch the actual request and return the result."""
print('{0} {1}'.format(request.method, request.url))
response = self.http.send(request, proxies=proxies, timeout=timeout,
allow_redirects=False)
response.raw = None # Make pickleable
return response
def handle(self):
"""Parse the RPC, make the call, and pickle up the return value."""
data = cPickle.load(self.rfile) # pylint: disable=E1101
method = data.pop('method')
try:
retval = getattr(self, 'do_{0}'.format(method))(**data)
except Exception as e:
# All exceptions should be passed to the client
retval = e
cPickle.dump(retval, self.wfile, # pylint: disable=E1101
cPickle.HIGHEST_PROTOCOL)
def run():
"""The entry point from the praw-multiprocess utility."""
parser = OptionParser(version='%prog {0}'.format(__version__))
parser.add_option('-a', '--addr', default='localhost',
help=('The address or host to listen on. Specify -a '
'0.0.0.0 to listen on all addresses. '
'Default: localhost'))
parser.add_option('-p', '--port', type='int', default=10101,
help=('The port to listen for requests on. '
'Default: 10101'))
options, _ = parser.parse_args()
try:
server = ThreadingTCPServer((options.addr, options.port),
RequestHandler)
except (socket.error, socket.gaierror) as exc: # Handle bind errors
print(exc)
sys.exit(1)
print('Listening on {0} port {1}'.format(options.addr, options.port))
try:
server.serve_forever() # pylint: disable=E1101
except KeyboardInterrupt:
server.socket.close() # pylint: disable=E1101
RequestHandler.http.close()
print('Goodbye!')
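
For context, a minimal client-side sketch of pointing PRAW at this server once `praw-multiprocess` is running; the bundled module paths are assumptions based on this commit, and the host/port mirror the defaults above:

    from rtv.packages import praw
    from rtv.packages.praw.handlers import MultiprocessHandler

    # Route all HTTP calls through the shared cache/rate-limit server.
    handler = MultiprocessHandler(host='localhost', port=10101)
    reddit = praw.Reddit(user_agent='multiprocess example', handler=handler)
    for submission in reddit.get_subreddit('python').get_hot(limit=1):
        print(submission)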

2003
rtv/packages/praw/objects.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,79 @@
[DEFAULT]
# The domain name PRAW will use to interact with the reddit site via its API.
api_domain: api.reddit.com
# Time, a float, in seconds, required between calls. See:
# http://code.reddit.com/wiki/API
api_request_delay: 2.0
# A boolean to indicate whether or not to check for package updates.
check_for_updates: True
# Time, a float, in seconds, to save the results of a get/post request.
cache_timeout: 30
# Log the API calls
# 0: no logging
# 1: log only the request URIs
# 2: log the request URIs as well as any POST data
log_requests: 0
# The domain name PRAW will use for oauth-related requests.
oauth_domain: oauth.reddit.com
# Whether or not to use HTTPS for oauth connections. This should only be
# changed for development environments.
oauth_https: True
# OAuth grant type: either `authorization_code` or `password`
oauth_grant_type: authorization_code
# The maximum length of unicode representations of Comment, Message and
# Submission objects. This is mainly used to fit them within a terminal window
# line. A negative value means no limit.
output_chars_limit: 80
# The domain name PRAW will use when permalinks are requested.
permalink_domain: www.reddit.com
# The domain name to use for short urls.
short_domain: redd.it
# A boolean to indicate whether the original API response should be stored on
# every object in the json_dict attribute. Default is False, as memory usage
# will double if enabled.
store_json_result: False
# Maximum time, a float, in seconds, before a single HTTP request times
# out. requests.exceptions.Timeout is raised upon timeout.
timeout: 45
# A boolean to indicate if SSL certificates should be validated. The
# default is True.
validate_certs: True
# Object to kind mappings
comment_kind: t1
message_kind: t4
redditor_kind: t2
submission_kind: t3
subreddit_kind: t5
[reddit]
# Uses the default settings
[reddit_oauth_test]
oauth_client_id: stJlUSUbPQe5lQ
oauth_client_secret: iU-LsOzyJH7BDVoq-qOWNEq2zuI
oauth_redirect_uri: https://127.0.0.1:65010/authorize_callback
[local_example]
api_domain: reddit.local
api_request_delay: 0
log_requests: 0
message_kind: t7
permalink_domain: reddit.local
short_domain:
submission_kind: t6
subreddit_kind: t5
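
For illustration, a hedged sketch of selecting one of the sections above when constructing a client; the import path follows this commit's bundling layout:

    from rtv.packages import praw

    # site_name picks a praw.ini section; [reddit] simply uses the defaults.
    reddit = praw.Reddit(user_agent='praw.ini example', site_name='reddit')
    print(reddit.config.api_url)  # derived from api_domain above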

View File

@@ -0,0 +1,45 @@
# This file is part of PRAW.
#
# PRAW is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# PRAW is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# PRAW. If not, see <http://www.gnu.org/licenses/>.
"""Provides the code to load PRAW's configuration file `praw.ini`."""
from __future__ import print_function, unicode_literals
import os
import sys
from six.moves import configparser
def _load_configuration():
"""Attempt to load settings from various praw.ini files."""
config = configparser.RawConfigParser()
module_dir = os.path.dirname(sys.modules[__name__].__file__)
if 'APPDATA' in os.environ: # Windows
os_config_path = os.environ['APPDATA']
elif 'XDG_CONFIG_HOME' in os.environ: # Modern Linux
os_config_path = os.environ['XDG_CONFIG_HOME']
elif 'HOME' in os.environ: # Legacy Linux
os_config_path = os.path.join(os.environ['HOME'], '.config')
else:
os_config_path = None
locations = [os.path.join(module_dir, 'praw.ini'), 'praw.ini']
if os_config_path is not None:
locations.insert(1, os.path.join(os_config_path, 'praw.ini'))
if not config.read(locations):
raise Exception('Could not find config file in any of: {0}'
.format(locations))
return config
CONFIG = _load_configuration()
del _load_configuration
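
As a rough usage sketch (module path assumed from the bundling layout), importing the module runs the loader once and exposes the merged configuration:

    from rtv.packages.praw import settings

    # Sections from every praw.ini found; values from later files in the search
    # order (user config dir, then the working directory) override earlier ones.
    print(settings.CONFIG.sections())                    # ['reddit', 'reddit_oauth_test', 'local_example']
    print(settings.CONFIG.get('DEFAULT', 'api_domain'))  # 'api.reddit.com' unless overridden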

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Update the project's bundled dependencies by cloning the upstream git
repository and copying in the files from its most recent commit.
"""
import os
import shutil
import subprocess
import tempfile
_filepath = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.abspath(os.path.join(_filepath, '..'))
PRAW_REPO = 'https://github.com/michael-lazar/praw3.git'
def main():
tmpdir = tempfile.mkdtemp()
subprocess.check_call(['git', 'clone', PRAW_REPO, tmpdir])
# Update the commit hash reference
os.chdir(tmpdir)
p = subprocess.Popen(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE)
p.wait()
commit = p.stdout.read().decode('ascii').strip()
print('Found commit %s' % commit)
# Note: no literal quotes around the expression; subprocess passes it verbatim.
regex = "s/^__praw_hash__ =.*$/__praw_hash__ = '%s'/g" % commit
packages_root = os.path.join(ROOT, 'rtv', 'packages', '__init__.py')
print('Updating commit hash in %s' % packages_root)
# BSD/macOS sed syntax; GNU sed takes -i without a separate empty suffix argument.
subprocess.check_call(['sed', '-i', '', regex, packages_root])
# Overwrite the project files
src = os.path.join(tmpdir, 'praw')
dest = os.path.join(ROOT, 'rtv', 'packages', 'praw')
print('Copying package files to %s' % dest)
shutil.rmtree(dest, ignore_errors=True)
shutil.copytree(src, dest)
# Cleanup
print('Removing directory %s' % tmpdir)
shutil.rmtree(tmpdir)
if __name__ == '__main__':
main()
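
A rough pure-Python equivalent of the sed step above, shown for illustration only; it assumes the same single-line __praw_hash__ assignment in rtv/packages/__init__.py:

    import re

    def update_praw_hash(path, commit):
        """Rewrite the __praw_hash__ assignment to point at the new commit."""
        with open(path) as fp:
            text = fp.read()
        text = re.sub(r"^__praw_hash__ =.*$",
                      "__praw_hash__ = '%s'" % commit,
                      text, flags=re.MULTILINE)
        with open(path, 'w') as fp:
            fp.write(text)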

View File

@@ -3,9 +3,9 @@ universal = 1
[metadata]
requires-dist =
praw>=3.5,<4
six
requests>=2.4.0
kitchen
beautifulsoup4
decorator
kitchen
mailcap-fix
requests>=2.4.0
six

View File

@@ -1,4 +1,3 @@
import sys
import setuptools
from version import __version__ as version
@@ -18,12 +17,13 @@ setuptools.setup(
package_data={'rtv': ['templates/*']},
data_files=[("share/man/man1", ["rtv.1"])],
install_requires=[
'praw >=3.5, <4',
'six',
'requests >=2.4.0',
'kitchen',
'beautifulsoup4',
'decorator',
'kitchen',
'mailcap-fix',
# For info on why this is pinned, see https://github.com/michael-lazar/rtv/issues/325
'requests >=2.4.0',
'six',
],
entry_points={'console_scripts': ['rtv=rtv.__main__:main']},
classifiers=[

View File

@@ -7,7 +7,6 @@ import logging
import threading
from functools import partial
import praw
import pytest
from vcr import VCR
from six.moves.urllib.parse import urlparse, parse_qs
@@ -15,6 +14,7 @@ from six.moves.BaseHTTPServer import HTTPServer
from rtv.oauth import OAuthHelper, OAuthHandler
from rtv.config import Config
from rtv.packages import praw
from rtv.terminal import Terminal
from rtv.subreddit_page import SubredditPage
from rtv.submission_page import SubmissionPage

View File

@@ -6,12 +6,12 @@ from itertools import islice
from collections import OrderedDict
import six
import praw
import pytest
from rtv import exceptions
from rtv.packages import praw
from rtv.content import (
Content, SubmissionContent, SubredditContent, SubscriptionContent)
from rtv import exceptions
try:
from unittest import mock

View File

@@ -2,9 +2,9 @@
from __future__ import unicode_literals
import requests
from praw.errors import OAuthException
from rtv.oauth import OAuthHelper, OAuthHandler
from rtv.packages.praw.errors import OAuthException
try:

View File

@@ -2,10 +2,10 @@
from __future__ import unicode_literals
import six
from praw.errors import NotFound
from rtv.subreddit_page import SubredditPage
from rtv import __version__
from rtv.subreddit_page import SubredditPage
from rtv.packages.praw.errors import NotFound
try:
from unittest import mock

View File

@@ -3,7 +3,6 @@ from __future__ import unicode_literals
import curses
import praw
import pytest
from rtv.subscription_page import SubscriptionPage

View File

@@ -10,8 +10,7 @@ import six
import pytest
from rtv.docs import HELP, COMMENT_EDIT_FILE
from rtv.objects import Color
from rtv.exceptions import TemporaryFileError, MailcapEntryNotFound
from rtv.exceptions import TemporaryFileError
try:
from unittest import mock