Skip to content
Snippets Groups Projects
Commit f3cae5a7 authored by Dan Michael O. Heggø's avatar Dan Michael O. Heggø
Browse files

[#125] Send assert=true when force_login=True

- When `Site.force_login` is True we can send `assert=true`
  with `edit` actions to have the api reject our edits if
  we got logged out for some reason.
- Adding a new exception `AssertUserFailedError` as a subclass
  of `LoginError`.
parent ac5bcfed
No related branches found
No related tags found
No related merge requests found
......@@ -56,6 +56,20 @@ class OAuthAuthorizationError(LoginError):
return self.info
class AssertUserFailedError(LoginError):
def __init__(self):
self.message = 'By default, mwclient protects you from ' + \
'accidentally editing without being logged in. If you ' + \
'actually want to edit without logging in, you can set ' + \
'force_login on the Site object to False.'
LoginError.__init__(self)
def __str__(self):
return self.message
class EmailError(MwClientError):
pass
......
......@@ -173,14 +173,7 @@ class Page(object):
"""Update the text of a section or the whole page by performing an edit operation.
"""
if not self.site.logged_in and self.site.force_login:
# Should we really check for this?
raise mwclient.errors.LoginError(
self.site,
'By default, mwclient protects you from accidentally editing '
'without being logged in. '
'If you actually want to edit without logging in, '
'you can set force_login on the Site object to False.'
)
raise mwclient.errors.AssertUserFailedError()
if self.site.blocked:
raise mwclient.errors.UserBlocked(self.site.blocked)
if not self.can('edit'):
......@@ -205,6 +198,9 @@ class Page(object):
data.update(kwargs)
if self.site.force_login:
data['assert'] = 'user'
def do_edit():
result = self.site.post('edit', title=self.name, text=text,
summary=summary, token=self.get_token('edit'),
......@@ -240,6 +236,8 @@ class Page(object):
'noimageredirect-anon', 'noimageredirect', 'noedit-anon',
'noedit'}:
raise mwclient.errors.ProtectedPageError(self, e.code, e.info)
elif e.code == 'assertuserfailed':
raise mwclient.errors.AssertUserFailedError()
else:
raise
......
......@@ -10,6 +10,7 @@ import mock
import mwclient
from mwclient.page import Page
from mwclient.client import Site
from mwclient.errors import APIError, AssertUserFailedError
try:
import json
......@@ -215,8 +216,11 @@ class TestPageApiArgs(unittest.TestCase):
'ns': 0, 'pageid': 2, 'revisions': [{'*': 'Hello world', 'timestamp': '2014-08-29T22:25:15Z'}], 'title': title
}}}}
def get_last_api_call_args(self):
args, kwargs = self.site.get.call_args
def get_last_api_call_args(self, http_method='POST'):
if http_method == 'GET':
args, kwargs = self.site.get.call_args
else:
args, kwargs = self.site.post.call_args
action = args[0]
args = args[1:]
kwargs.update(args)
......@@ -228,7 +232,7 @@ class TestPageApiArgs(unittest.TestCase):
def test_get_page_text(self):
# Check that page.text() works, and that a correct API call is made
text = self.page.text()
args = self.get_last_api_call_args()
args = self.get_last_api_call_args(http_method='GET')
assert text == self.page_text
assert args == {
......@@ -253,17 +257,56 @@ class TestPageApiArgs(unittest.TestCase):
def test_get_section_text(self):
# Check that the 'rvsection' parameter is sent to the API
text = self.page.text(section=0)
args = self.get_last_api_call_args()
args = self.get_last_api_call_args(http_method='GET')
assert args['rvsection'] == '0'
def test_get_text_expanded(self):
# Check that the 'rvexpandtemplates' parameter is sent to the API
text = self.page.text(expandtemplates=True)
args = self.get_last_api_call_args()
args = self.get_last_api_call_args(http_method='GET')
assert args['rvexpandtemplates'] == '1'
def test_assertuser_true(self):
# Check that assert=user is sent when force_login=True
self.site.blocked = False
self.site.rights = ['read', 'edit']
self.site.logged_in = True
self.site.force_login = True
self.site.api.return_value = {
'edit': {'result': 'Ok'}
}
self.page.save('Some text')
args = self.get_last_api_call_args()
assert args['assert'] == 'user'
def test_assertuser_false(self):
# Check that assert=user is not sent when force_login=False
self.site.blocked = False
self.site.rights = ['read', 'edit']
self.site.logged_in = False
self.site.force_login = False
self.site.api.return_value = {
'edit': {'result': 'Ok'}
}
self.page.save('Some text')
args = self.get_last_api_call_args()
assert 'assert' not in args
def test_handle_edit_error_assertuserfailed(self):
# Check that AssertUserFailedError is triggered
api_error = APIError('assertuserfailed',
'Assertion that the user is logged in failed',
'See https://en.wikipedia.org/w/api.php for API usage')
with pytest.raises(AssertUserFailedError):
self.page.handle_edit_error(api_error, 'n/a')
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment