Filter nsfw content when safe search is enabled
This commit is contained in:
@@ -437,12 +437,13 @@ class SubredditContent(Content):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, submissions, loader, order=None,
|
def __init__(self, name, submissions, loader, order=None,
|
||||||
max_title_rows=4, query=None):
|
max_title_rows=4, query=None, filter_nsfw=False):
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.order = order
|
self.order = order
|
||||||
self.query = query
|
self.query = query
|
||||||
self.max_title_rows = max_title_rows
|
self.max_title_rows = max_title_rows
|
||||||
|
self.filter_nsfw = filter_nsfw
|
||||||
self._loader = loader
|
self._loader = loader
|
||||||
self._submissions = submissions
|
self._submissions = submissions
|
||||||
self._submission_data = []
|
self._submission_data = []
|
||||||
@@ -605,9 +606,11 @@ class SubredditContent(Content):
|
|||||||
# display name with the one returned by the request.
|
# display name with the one returned by the request.
|
||||||
display_name = '/r/{0}'.format(subreddit.display_name)
|
display_name = '/r/{0}'.format(subreddit.display_name)
|
||||||
|
|
||||||
|
filter_nsfw = (reddit.user and reddit.user.over_18 is False)
|
||||||
|
|
||||||
# We made it!
|
# We made it!
|
||||||
return cls(display_name, submissions, loader, order=display_order,
|
return cls(display_name, submissions, loader, order=display_order,
|
||||||
query=query)
|
query=query, filter_nsfw=filter_nsfw)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def range(self):
|
def range(self):
|
||||||
@@ -625,6 +628,7 @@ class SubredditContent(Content):
|
|||||||
if index < 0:
|
if index < 0:
|
||||||
raise IndexError
|
raise IndexError
|
||||||
|
|
||||||
|
nsfw_count = 0
|
||||||
while index >= len(self._submission_data):
|
while index >= len(self._submission_data):
|
||||||
try:
|
try:
|
||||||
with self._loader('Loading more submissions'):
|
with self._loader('Loading more submissions'):
|
||||||
@@ -634,6 +638,21 @@ class SubredditContent(Content):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise IndexError
|
raise IndexError
|
||||||
else:
|
else:
|
||||||
|
|
||||||
|
# Skip NSFW posts based on the reddit user's profile settings.
|
||||||
|
# If we only see 20+ NSFW posts, stop looping and bail out.
|
||||||
|
# This allows us to skip making an additional API call to check
|
||||||
|
# if a subreddit is over18 (which doesn't work for things like
|
||||||
|
# multireddits anyway)
|
||||||
|
if self.filter_nsfw and submission.over_18:
|
||||||
|
nsfw_count += 1
|
||||||
|
if not self._submission_data and nsfw_count >= 20:
|
||||||
|
raise exceptions.SubredditError(
|
||||||
|
'You must be over 18+ to view this subreddit')
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
nsfw_count = 0
|
||||||
|
|
||||||
if hasattr(submission, 'title'):
|
if hasattr(submission, 'title'):
|
||||||
data = self.strip_praw_submission(submission)
|
data = self.strip_praw_submission(submission)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ class SubmissionError(RTVError):
|
|||||||
"Submission could not be loaded"
|
"Submission could not be loaded"
|
||||||
|
|
||||||
|
|
||||||
|
class SubredditError(RTVError):
|
||||||
|
"Subreddit could not be loaded"
|
||||||
|
|
||||||
|
|
||||||
class NoSubmissionsError(RTVError):
|
class NoSubmissionsError(RTVError):
|
||||||
"No submissions for the given page"
|
"No submissions for the given page"
|
||||||
|
|
||||||
|
|||||||
4987
tests/cassettes/test_content_subreddit_nsfw_filter.yaml
Normal file
4987
tests/cassettes/test_content_subreddit_nsfw_filter.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -469,6 +469,41 @@ def test_content_subreddit_me(reddit, oauth, refresh_token, terminal):
|
|||||||
assert terminal.loader.exception.name == '/u/me'
|
assert terminal.loader.exception.name == '/u/me'
|
||||||
|
|
||||||
|
|
||||||
|
def test_content_subreddit_nsfw_filter(reddit, oauth, refresh_token, terminal):
|
||||||
|
|
||||||
|
# NSFW subreddits should load if not logged in
|
||||||
|
name = '/r/ImGoingToHellForThis'
|
||||||
|
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||||
|
|
||||||
|
# Log in
|
||||||
|
oauth.config.refresh_token = refresh_token
|
||||||
|
oauth.authorize()
|
||||||
|
|
||||||
|
# Make sure the API parameter hasn't changed
|
||||||
|
assert reddit.user.over_18 is not None
|
||||||
|
|
||||||
|
# Turn on safe search
|
||||||
|
reddit.user.over_18 = False
|
||||||
|
|
||||||
|
# Should refuse to load this subreddit
|
||||||
|
with pytest.raises(exceptions.SubredditError):
|
||||||
|
name = '/r/ImGoingToHellForThis'
|
||||||
|
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||||
|
|
||||||
|
# Should filter out all of the nsfw posts
|
||||||
|
name = '/r/ImGoingToHellForThis+python'
|
||||||
|
content = SubredditContent.from_name(reddit, name, terminal.loader)
|
||||||
|
for data in islice(content.iterate(0, 1), 50):
|
||||||
|
assert data['object'].over_18 is False
|
||||||
|
|
||||||
|
# Turn off safe search
|
||||||
|
reddit.user.over_18 = True
|
||||||
|
|
||||||
|
# The NSFW subreddit should load now
|
||||||
|
name = '/r/ImGoingToHellForThis'
|
||||||
|
SubredditContent.from_name(reddit, name, terminal.loader)
|
||||||
|
|
||||||
|
|
||||||
def test_content_subscription(reddit, terminal):
|
def test_content_subscription(reddit, terminal):
|
||||||
|
|
||||||
# Not logged in
|
# Not logged in
|
||||||
|
|||||||
Reference in New Issue
Block a user