diff --git a/mwclient/errors.py b/mwclient/errors.py index 3dcc83a6044934789b96691da4f8006dd071e891..4971afb483cde23599423483c5c9b41bcc1772ac 100644 --- a/mwclient/errors.py +++ b/mwclient/errors.py @@ -56,6 +56,20 @@ class OAuthAuthorizationError(LoginError): return self.info +class AssertUserFailedError(LoginError): + + def __init__(self): + self.message = 'By default, mwclient protects you from ' + \ + 'accidentally editing without being logged in. If you ' + \ + 'actually want to edit without logging in, you can set ' + \ + 'force_login on the Site object to False.' + + LoginError.__init__(self) + + def __str__(self): + return self.message + + class EmailError(MwClientError): pass diff --git a/mwclient/page.py b/mwclient/page.py index 54c763ce07ea10b5144af67e14d5c05244755306..ae495baa17509566b71d4029ee248a47cbf272c9 100644 --- a/mwclient/page.py +++ b/mwclient/page.py @@ -173,14 +173,7 @@ class Page(object): """Update the text of a section or the whole page by performing an edit operation. """ if not self.site.logged_in and self.site.force_login: - # Should we really check for this? - raise mwclient.errors.LoginError( - self.site, - 'By default, mwclient protects you from accidentally editing ' - 'without being logged in. ' - 'If you actually want to edit without logging in, ' - 'you can set force_login on the Site object to False.' - ) + raise mwclient.errors.AssertUserFailedError() if self.site.blocked: raise mwclient.errors.UserBlocked(self.site.blocked) if not self.can('edit'): @@ -205,6 +198,9 @@ class Page(object): data.update(kwargs) + if self.site.force_login: + data['assert'] = 'user' + def do_edit(): result = self.site.post('edit', title=self.name, text=text, summary=summary, token=self.get_token('edit'), @@ -240,6 +236,8 @@ class Page(object): 'noimageredirect-anon', 'noimageredirect', 'noedit-anon', 'noedit'}: raise mwclient.errors.ProtectedPageError(self, e.code, e.info) + elif e.code == 'assertuserfailed': + raise mwclient.errors.AssertUserFailedError() else: raise diff --git a/tests/test_page.py b/tests/test_page.py index d7babe51b678b5e2ba3e8f2848721c4ddce34e28..45a4edd844dedd513b01ecd48c36d945415c8f78 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -10,6 +10,7 @@ import mock import mwclient from mwclient.page import Page from mwclient.client import Site +from mwclient.errors import APIError, AssertUserFailedError try: import json @@ -215,8 +216,11 @@ class TestPageApiArgs(unittest.TestCase): 'ns': 0, 'pageid': 2, 'revisions': [{'*': 'Hello world', 'timestamp': '2014-08-29T22:25:15Z'}], 'title': title }}}} - def get_last_api_call_args(self): - args, kwargs = self.site.get.call_args + def get_last_api_call_args(self, http_method='POST'): + if http_method == 'GET': + args, kwargs = self.site.get.call_args + else: + args, kwargs = self.site.post.call_args action = args[0] args = args[1:] kwargs.update(args) @@ -228,7 +232,7 @@ class TestPageApiArgs(unittest.TestCase): def test_get_page_text(self): # Check that page.text() works, and that a correct API call is made text = self.page.text() - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert text == self.page_text assert args == { @@ -253,17 +257,56 @@ class TestPageApiArgs(unittest.TestCase): def test_get_section_text(self): # Check that the 'rvsection' parameter is sent to the API text = self.page.text(section=0) - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert args['rvsection'] == '0' def test_get_text_expanded(self): # Check that the 'rvexpandtemplates' parameter is sent to the API text = self.page.text(expandtemplates=True) - args = self.get_last_api_call_args() + args = self.get_last_api_call_args(http_method='GET') assert args['rvexpandtemplates'] == '1' + def test_assertuser_true(self): + # Check that assert=user is sent when force_login=True + self.site.blocked = False + self.site.rights = ['read', 'edit'] + self.site.logged_in = True + self.site.force_login = True + + self.site.api.return_value = { + 'edit': {'result': 'Ok'} + } + self.page.save('Some text') + args = self.get_last_api_call_args() + + assert args['assert'] == 'user' + + def test_assertuser_false(self): + # Check that assert=user is not sent when force_login=False + self.site.blocked = False + self.site.rights = ['read', 'edit'] + self.site.logged_in = False + self.site.force_login = False + + self.site.api.return_value = { + 'edit': {'result': 'Ok'} + } + self.page.save('Some text') + args = self.get_last_api_call_args() + + assert 'assert' not in args + + def test_handle_edit_error_assertuserfailed(self): + # Check that AssertUserFailedError is triggered + api_error = APIError('assertuserfailed', + 'Assertion that the user is logged in failed', + 'See https://en.wikipedia.org/w/api.php for API usage') + + with pytest.raises(AssertUserFailedError): + self.page.handle_edit_error(api_error, 'n/a') + if __name__ == '__main__': unittest.main()