Merge pull request #423 from michael-lazar/filter_nsfw
Filter nsfw content when safe search is enabled
This commit is contained in:
@@ -437,12 +437,13 @@ class SubredditContent(Content):
|
||||
"""
|
||||
|
||||
def __init__(self, name, submissions, loader, order=None,
|
||||
max_title_rows=4, query=None):
|
||||
max_title_rows=4, query=None, filter_nsfw=False):
|
||||
|
||||
self.name = name
|
||||
self.order = order
|
||||
self.query = query
|
||||
self.max_title_rows = max_title_rows
|
||||
self.filter_nsfw = filter_nsfw
|
||||
self._loader = loader
|
||||
self._submissions = submissions
|
||||
self._submission_data = []
|
||||
@@ -605,9 +606,11 @@ class SubredditContent(Content):
|
||||
# display name with the one returned by the request.
|
||||
display_name = '/r/{0}'.format(subreddit.display_name)
|
||||
|
||||
filter_nsfw = (reddit.user and reddit.user.over_18 is False)
|
||||
|
||||
# We made it!
|
||||
return cls(display_name, submissions, loader, order=display_order,
|
||||
query=query)
|
||||
query=query, filter_nsfw=filter_nsfw)
|
||||
|
||||
@property
|
||||
def range(self):
|
||||
@@ -625,6 +628,7 @@ class SubredditContent(Content):
|
||||
if index < 0:
|
||||
raise IndexError
|
||||
|
||||
nsfw_count = 0
|
||||
while index >= len(self._submission_data):
|
||||
try:
|
||||
with self._loader('Loading more submissions'):
|
||||
@@ -634,6 +638,21 @@ class SubredditContent(Content):
|
||||
except StopIteration:
|
||||
raise IndexError
|
||||
else:
|
||||
|
||||
# Skip NSFW posts based on the reddit user's profile settings.
|
||||
# If we only see 20+ NSFW posts, stop looping and bail out.
|
||||
# This allows us to skip making an additional API call to check
|
||||
# if a subreddit is over18 (which doesn't work for things like
|
||||
# multireddits anyway)
|
||||
if self.filter_nsfw and submission.over_18:
|
||||
nsfw_count += 1
|
||||
if not self._submission_data and nsfw_count >= 20:
|
||||
raise exceptions.SubredditError(
|
||||
'You must be over 18+ to view this subreddit')
|
||||
continue
|
||||
else:
|
||||
nsfw_count = 0
|
||||
|
||||
if hasattr(submission, 'title'):
|
||||
data = self.strip_praw_submission(submission)
|
||||
else:
|
||||
|
||||
@@ -22,6 +22,10 @@ class SubmissionError(RTVError):
|
||||
"Submission could not be loaded"
|
||||
|
||||
|
||||
class SubredditError(RTVError):
|
||||
"Subreddit could not be loaded"
|
||||
|
||||
|
||||
class NoSubmissionsError(RTVError):
|
||||
"No submissions for the given page"
|
||||
|
||||
|
||||
4987
tests/cassettes/test_content_subreddit_nsfw_filter.yaml
Normal file
4987
tests/cassettes/test_content_subreddit_nsfw_filter.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -469,6 +469,41 @@ def test_content_subreddit_me(reddit, oauth, refresh_token, terminal):
|
||||
assert terminal.loader.exception.name == '/u/me'
|
||||
|
||||
|
||||
def test_content_subreddit_nsfw_filter(reddit, oauth, refresh_token, terminal):
|
||||
|
||||
# NSFW subreddits should load if not logged in
|
||||
name = '/r/ImGoingToHellForThis'
|
||||
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||
|
||||
# Log in
|
||||
oauth.config.refresh_token = refresh_token
|
||||
oauth.authorize()
|
||||
|
||||
# Make sure the API parameter hasn't changed
|
||||
assert reddit.user.over_18 is not None
|
||||
|
||||
# Turn on safe search
|
||||
reddit.user.over_18 = False
|
||||
|
||||
# Should refuse to load this subreddit
|
||||
with pytest.raises(exceptions.SubredditError):
|
||||
name = '/r/ImGoingToHellForThis'
|
||||
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||
|
||||
# Should filter out all of the nsfw posts
|
||||
name = '/r/ImGoingToHellForThis+python'
|
||||
content = SubredditContent.from_name(reddit, name, terminal.loader)
|
||||
for data in islice(content.iterate(0, 1), 50):
|
||||
assert data['object'].over_18 is False
|
||||
|
||||
# Turn off safe search
|
||||
reddit.user.over_18 = True
|
||||
|
||||
# The NSFW subreddit should load now
|
||||
name = '/r/ImGoingToHellForThis'
|
||||
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||
|
||||
|
||||
def test_content_subscription(reddit, terminal):
|
||||
|
||||
# Not logged in
|
||||
|
||||
Reference in New Issue
Block a user