From 3505e2eb505ce8481340e0ebe69ff469b8d82f4b Mon Sep 17 00:00:00 2001
From: Dylann Cordel <cordel.d@free.fr>
Date: Wed, 5 Jul 2023 14:24:09 +0200
Subject: [PATCH] add methods to manage users

* simple user creation
* block user
* unblock user
* get basic info about user
* get groups
* add groups
* remove groups
* set groups (get current ones and add/remove what is needed)
---
 mwclient/client.py  | 367 +++++++++++++++++++++++++++-
 mwclient/errors.py  |   8 +
 test/test_client.py | 575 +++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 936 insertions(+), 14 deletions(-)

diff --git a/mwclient/client.py b/mwclient/client.py
index bd1f715..b395d29 100644
--- a/mwclient/client.py
+++ b/mwclient/client.py
@@ -1,11 +1,10 @@
-import warnings
+import json
 import logging
-
+import warnings
 from collections import OrderedDict
 
-import json
 import requests
-from requests.auth import HTTPBasicAuth, AuthBase
+from requests.auth import AuthBase, HTTPBasicAuth
 from requests_oauthlib import OAuth1
 
 import mwclient.errors as errors
@@ -87,6 +86,9 @@ class Site:
             (where e is the exception object) and will be one of the API:Login errors. The
             most common error code is "Failed", indicating a wrong username or password.
     """
+    AVAILABLE_TOKEN_TYPES = {
+        'createaccount', 'csrf', 'login', 'patrol', 'rollback', 'userrights', 'watch'
+    }
     api_limit = 500
 
     def __init__(self, host, path='/w/', ext='.php', pool=None, retry_timeout=30,
@@ -840,7 +842,7 @@ class Site:
         if self.version is None or self.version[:2] >= (1, 24):
             # The 'csrf' (cross-site request forgery) token introduced in 1.24 replaces
             # the majority of older tokens, like edittoken and movetoken.
-            if type not in {'watch', 'patrol', 'rollback', 'userrights', 'login'}:
+            if type not in self.AVAILABLE_TOKEN_TYPES:
                 type = 'csrf'
 
         if type not in self.tokens:
@@ -873,6 +875,361 @@ class Site:
 
         return self.tokens[type]
 
+    def create_user(self, username, password, **kwargs):
+        """
+        Creates user on the wiki with at least a username and a password.
+
+            >>> try:
+            ...     exists = site.create_user('SomeUser', 'Some password',
+            ...                               **my_extra_kwargs)
+            ... except mwclient.errors.CreateError as e:
+            ...     exists = e.code == 'userexists':
+            ...     if not exists:
+            ...         print('Can not create user: %s' % e)
+            ... if exists:
+            ...     user = site.get_user('SomeUser')
+
+        Args:
+            username (str): User name of the user to create
+            password (str): User password of the user to create
+            **kwargs (dict): Other required fields depending on your mediawiki conf
+                             eg: an email
+
+        Returns:
+            Bool (True) if user has beend created, else raise UserCreateError
+
+        Raises:
+            UserCreateError (mwclient.errors.UserCreateError): User can not be created
+                                                               for some reason
+            APIError (mwclient.errors.APIError): Other API errors
+        """
+
+        if 'createtoken' not in kwargs:
+            kwargs['createtoken'] = self.get_token('createaccount')
+        if 'retype' not in kwargs:
+            kwargs['retype'] = password
+        if 'continue' not in kwargs and 'createreturnurl' not in kwargs:
+            # should be great if API didn't require this...
+            kwargs['createreturnurl'] = '%s://%s' % (self.scheme, self.host)
+        info = self.post('createaccount', username=username, password=password, **kwargs)
+        if info['createaccount']['status'] == 'FAIL':
+            raise errors.UserCreateError(
+                info['createaccount'].get('messagecode'),
+                info['createaccount'].get('message'),
+                kwargs=kwargs
+            )
+        return True
+
+    def get_user(self, username=None, userid=None,
+                 prop='registration|groups|blockinfo'):
+        """
+        Retrieves user informations
+        (registration, groups and blockinfo infos by default).
+
+            >>> try:
+            ...     site.get_user('SomeUser')
+            ... except mwclient.errors.UserNotFound:
+            ...     print('User seems to not exist')
+
+        Args:
+            username (str): User name to retrieve, OR
+            userid (int): User ID to retrieve
+
+        Returns:
+            Dictionary of the JSON user informations response
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            UserNotFound (mwclient.errors.UserNotFound): User has not been found
+            APIError (mwclient.errors.APIError): Other API errors
+        """
+        if (
+            (username is None and userid is None)
+            or (username is not None and userid is not None)
+        ):
+            raise ValueError('username OR userid are required')
+
+        kwargs = {
+            'list': 'users',
+            'usprop': prop,
+        }
+        if username is not None:
+            kwargs['ususers'] = username
+        else:
+            kwargs['ususerids'] = userid
+        resp = self.get('query', **kwargs)
+        if 'missing' in resp['query']['users'][0]:
+            raise errors.UserNotFound(code='missing', info=None, kwargs=kwargs)
+        return resp['query']['users'][0]
+
+    def block_user(self, username=None, userid=None, reason=None, tags=None, **kwargs):
+        """
+        Blocks an user
+
+            >>> try:
+            ...     site.block_user('SomeUser')
+            ... except mwclient.errors.APIError:
+            ...     print('Can not block the user: %s' % e)
+
+        Args:
+            username (str): User name to block, OR
+            userid (int): User ID to block
+            **kwargs (dict): Additional arguments are passed on to the API
+
+        Returns:
+            Dictionary of the JSON block informations response
+            (see https://www.mediawiki.org/wiki/API:Block)
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            APIError (mwclient.errors.APIError): API errors (see codes available
+                      on https://www.mediawiki.org/wiki/API:Block#Possible_errors)
+        """
+        return self._block_unblock_user(True, username, userid, reason, tags, **kwargs)
+
+    def unblock_user(self, username=None, userid=None, reason=None, tags=None, **kwargs):
+        """
+        Unblocks an user
+
+            >>> try:
+            ...     site.unblock_user('SomeUser')
+            ... except mwclient.errors.APIError:
+            ...     print('Can not unblock the user: %s' % e)
+
+        Args:
+            username (str): User name to unblock, OR
+            userid (int): User ID to unblock
+            **kwargs (dict): Additional arguments are passed on to the API
+
+        Returns:
+            Dictionary of the JSON unblock informations response
+            (see https://www.mediawiki.org/wiki/API:Block)
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            APIError (mwclient.errors.APIError): API errors (see codes available
+                      on https://www.mediawiki.org/wiki/API:Block#Possible_errors)
+        """
+        return self._block_unblock_user(False, username, userid, reason, tags, **kwargs)
+
+    def _block_unblock_user(self, block=True, username=None, userid=None,
+                            reason=None, tags=None, **kwargs):
+        """
+        Blocks or unblocks a user. You should not use this protected method: prefers
+        the use of block_user or unblock_user. See docs of theses public methods.
+        """
+        if (
+            (username is None and userid is None)
+            or (username is not None and userid is not None)
+        ):
+            raise ValueError('username OR userid are required')
+        kwargs['token'] = self.get_token('csrf')
+
+        if username is not None:
+            kwargs['user'] = username
+        else:
+            kwargs['userid'] = userid
+        if reason:
+            kwargs['reason'] = reason
+        if tags:
+            kwargs['tags'] = '|'.join(tags)
+        action = 'block' if block else 'unblock'
+        resp = self.post(action, **kwargs)
+        if resp:
+            if action == 'block' and action in resp:
+                return resp[action]
+            elif action == 'unblock' and 'id' in resp:
+                return resp
+        raise errors.APIError('unkown', 'Can not %s user' % action, kwargs=kwargs)
+
+    def get_user_groups(self, username=None, userid=None):
+        """
+        Retrieves groups the user belongs to
+
+            >>> try:
+            ...     site.get_user_groups('SomeUser')
+            ... except mwclient.errors.APIError:
+            ...     print('Can not retrieves user\'s groups: %s' % e)
+
+        Args:
+            username (str): User name concerned, OR
+            userid (int): User ID concerned
+
+        Returns:
+            List of the current groups names the user belongs to.
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            UserNotFound (mwclient.errors.UserNotFound): wanted user doest not exists
+            APIError (mwclient.errors.APIError): other API errors
+        """
+        user = self.get_user(username, userid, 'groups')
+        return user.get('groups', [])
+
+    def add_user_groups(self, username=None, userid=None, groups=None,
+                        expiry=None, reason=None, tags=None):
+        """
+        Adds groups to the user
+
+            >>> try:
+            ...     site.add_user_groups('SomeUser', groups=['sysop', 'bureaucrat'])
+            ... except mwclient.errors.APIError:
+            ...     print('Can not add groups to user: %s' % e)
+
+        Args:
+            username (str): User name concerned, OR
+            userid (int): User ID concerned
+            groups (list): group's names to add to the user
+            expiry (date): optionnal - expiration date of current membership(s)
+            reason (string); optionnal - reason why those groups are added
+            tags (list): list of tags to apply to the entry in the user rights log
+
+        Returns:
+            list of group really added
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            UserNotFound (mwclient.errors.UserNotFound): wanted user doest not exists
+            APIError (mwclient.errors.APIError): other API errors
+        """
+        res = self._set_user_groups(
+            username=username, userid=userid,
+            added_groups=groups, expiry=expiry, reason=reason, tags=tags
+        )
+        if res:
+            return res.get('added', [])
+        return []
+
+    def remove_user_groups(self, username=None, userid=None, groups=None,
+                           reason=None, tags=None):
+        """
+        Removes groups to the user
+
+            >>> try:
+            ...     site.remove_user_groups('SomeUser', groups=['sysop', 'bureaucrat'])
+            ... except mwclient.errors.APIError:
+            ...     print('Can not remove groups to user: %s' % e)
+
+        Args:
+            username (str): User name concerned, OR
+            userid (int): User ID concerned
+            groups (list): group's names to remove to the user
+            reason (string); optionnal - reason why those groups are removed
+            tags (list): list of tags to apply to the entry in the user rights log
+
+        Returns:
+            list of group really removed
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            UserNotFound (mwclient.errors.UserNotFound): wanted user doest not exists
+            APIError (mwclient.errors.APIError): other API errors
+        """
+        res = self._set_user_groups(
+            username=username, userid=userid,
+            removed_groups=groups, reason=reason, tags=tags
+        )
+        return res.get('removed', []) if res else []
+
+    def set_user_groups(self, username=None, userid=None, groups=None,
+                        expiry=None, reason=None, tags=None):
+        """
+        Set groups to the user (add and remove groups depending current memberships)
+
+            >>> try:
+            ...     res = site.set_user_groups('SomeUser',
+            ...                                groups=['sysop', 'bureaucrat'])
+            ... except mwclient.errors.APIError:
+            ...     print('Can not add groups to user: %s' % e)
+            ... else:
+            ...     print('really added groups: %s ; really removed groups: %s',
+            ...            res['added'], res['removed'])
+
+        Args:
+            username (str): User name concerned, OR
+            userid (int): User ID concerned
+            groups (list): ALL group's names the user must belongs too
+                           (group's names which are not listed here will be removed from
+                           the user's memberships)
+            reason (string); optionnal - reason why those groups are added / removed
+            tags (list): list of tags to apply to the entry in the user rights log
+
+        Returns:
+            dict {'added': <list>, 'removed': <list>} with really affected groups
+
+        Raises:
+            ValueError: username and userid params are both not set or set
+            UserNotFound (mwclient.errors.UserNotFound): wanted user doest not exists
+            APIError (mwclient.errors.APIError): other API errors
+        """
+        groups = set(groups)
+        current_groups = set(self.get_user_groups(username=username, userid=userid))
+        removed_groups = current_groups - groups
+        added_groups = groups - current_groups
+        res = self._set_user_groups(
+            username=username, userid=userid,
+            added_groups=added_groups,
+            removed_groups=removed_groups, expiry=expiry, reason=reason,
+            tags=tags
+        )
+        return (
+            {'added': res.get('added', []), 'removed': res.get('removed', [])} if res
+            else {'added': [], 'removed': []}
+        )
+
+    def _set_user_groups(self, username=None, userid=None,
+                         added_groups=None, removed_groups=None,
+                         expiry=None, reason=None, tags=None):
+        """
+        Add and/or remove groups to the user. Please do not use this protected method:
+        you sould use [add|remove|set]_user_groups instead.
+        Please refers to those methods doc string
+        """
+
+        if not added_groups and not removed_groups:
+            return False
+        if (not username and not userid) or (username and userid):
+            raise ValueError('username OR userid are required')
+
+        kwargs = {
+            'token': self.get_token('userrights'),
+        }
+        if username:
+            kwargs['user'] = username
+        if userid:
+            kwargs['userid'] = userid
+        if added_groups:
+            kwargs['add'] = '|'.join(added_groups)
+            if expiry:
+                if isinstance(expiry, str):
+                    kwargs['expiry'] = expiry
+                else:
+                    try:
+                        iterator = iter(expiry)
+                    except TypeError:
+                        expiry = ['%s' % expiry]
+                    else:
+                        expiry = ['%s' % e for e in iterator]
+                    kwargs['expiry'] = '|'.join(expiry)
+
+        if removed_groups:
+            kwargs['remove'] = '|'.join(removed_groups)
+
+        if reason:
+            kwargs['reason'] = reason
+        if tags:
+            kwargs['tags'] = '|'.join(tags)
+
+        try:
+            res = self.post('userrights', **kwargs)
+        except errors.APIError as e:
+            if e.code == 'nosuchuser':
+                raise errors.UserNotFound(code=e.code, info=None, kwargs=kwargs)
+            raise
+        else:
+            res = res.get('userrights', {})
+            return {'added': res.get('added', []), 'removed': res.get('removed', [])}
+
     def upload(self, file=None, filename=None, description='', ignore=False,
                file_size=None, url=None, filekey=None, comment=None):
         """Upload a file to the site.
diff --git a/mwclient/errors.py b/mwclient/errors.py
index 845cf31..ed505d0 100644
--- a/mwclient/errors.py
+++ b/mwclient/errors.py
@@ -22,6 +22,14 @@ class APIError(MwClientError):
         super(APIError, self).__init__(code, info, kwargs)
 
 
+class UserNotFound(APIError):
+    pass
+
+
+class UserCreateError(APIError):
+    pass
+
+
 class InsufficientPermission(MwClientError):
     pass
 
diff --git a/test/test_client.py b/test/test_client.py
index c86077f..01afecf 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -1,18 +1,19 @@
-from io import StringIO
+import json
+import logging
+import time
 import unittest
-import pytest
+import unittest.mock as mock
+from copy import deepcopy
+from datetime import date
+from io import StringIO
+
 import mwclient
-import logging
+import pkg_resources  # part of setuptools
+import pytest
 import requests
 import responses
-import pkg_resources  # part of setuptools
-import time
-import json
 from requests_oauthlib import OAuth1
 
-import unittest.mock as mock
-
-
 if __name__ == "__main__":
     print()
     print("Note: Running in stand-alone mode. Consult the README")
@@ -860,5 +861,561 @@ class TestClientPatrol(TestCase):
         get_token.assert_called_once_with('edit')
 
 
+class TestUser(TestCase):
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_create_user(self, raw_api, site_init):
+        createaccount_token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'createaccounttoken': createaccount_token}}
+                }
+            elif 'username' in kwargs:
+                assert kwargs['createtoken'] == createaccount_token
+                assert kwargs['retype'] == kwargs['password']
+                assert kwargs.get('createreturnurl')
+                return {
+                    'createaccount': {'status': 'PASS'}
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        password = 'password'
+        url = '%s://%s' % (site.scheme, site.host)
+        site.create_user(username='myusername', password=password)
+
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET',
+                                         meta='tokens', type='createaccount')
+        assert call_args[1] == mock.call('createaccount', 'POST',
+                                         username='myusername', password=password,
+                                         retype=password, createreturnurl=url,
+                                         createtoken=createaccount_token)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_create_user_fail_badretype(self, raw_api, site_init):
+        createaccount_token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'createaccounttoken': createaccount_token}}
+                }
+            elif 'username' in kwargs:
+                assert kwargs['createtoken'] == createaccount_token
+                assert kwargs['retype'] != kwargs['password']
+                assert kwargs.get('createreturnurl')
+                return {
+                    'createaccount': {'status': 'FAIL',
+                                      'messagecode': 'badretype',
+                                      'message': 'oups'}
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        password = 'password'
+        url = '%s://%s' % (site.scheme, site.host)
+
+        with pytest.raises(mwclient.errors.UserCreateError):
+            site.create_user(username='myusername', password=password, retype=password[::-1])
+
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET',
+                                         meta='tokens', type='createaccount')
+        assert call_args[1] == mock.call('createaccount', 'POST',
+                                         username='myusername', password=password,
+                                         retype=password[::-1], createreturnurl=url,
+                                         createtoken=createaccount_token)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user(self, raw_api, site_init):
+        def side_effect(*args, **kwargs):
+            if kwargs.get('list') == 'users':
+                return {
+                    'query': {
+                        'users': [{
+                            'userid': 1,
+                            'user': 'myusername',
+                            'groups': ['*', 'user']
+                        }]
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        site.get_user(username='myusername')
+
+        call_kwargs = {
+            'ususers': 'myusername',
+            'list': 'users',
+            'continue': '',
+            'meta': 'userinfo',
+            'uiprop': 'blockinfo|hasmsg',
+            'usprop': 'registration|groups|blockinfo'
+        }
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 1
+        assert call_args[0] == mock.call('query', 'GET', **call_kwargs)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user_fail_notfound(self, raw_api, site_init):
+        def side_effect(*args, **kwargs):
+            if kwargs.get('list') == 'users':
+                ret = {
+                    'query': {
+                        'users': [{
+                            'missing': ''
+                        }]
+                    }
+                }
+                if 'ususers' in kwargs:
+                    ret['query']['users'][0]['user'] = kwargs['ususers']
+                elif 'ususerids' in kwargs:
+                    ret['query']['users'][0]['userid'] = kwargs['ususerids']
+                return ret
+
+        raw_api.side_effect = side_effect
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(mwclient.errors.UserNotFound):
+            site.get_user(username='notfounduser')
+        with pytest.raises(mwclient.errors.UserNotFound):
+            site.get_user(userid=42)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user_fail_params(self, raw_api, site_init):
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(ValueError):
+            site.get_user(username=None, userid=None)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_block_user(self, raw_api, site_init):
+        csrf_token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'csrftoken': csrf_token}}
+                }
+            else:
+                return {
+                    'block': {
+                        'user': 'myusername',
+                        'userID': 1,
+                        'expiry': 'infinite',
+                        'id': 1,
+                        'reason': kwargs['reason']
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        site.block_user(username='myusername', reason='Test',
+                        tags=['knock', 'knock'])
+
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='csrf')
+        assert call_args[1] == mock.call('block', 'POST',
+                                         user='myusername', reason='Test',
+                                         tags='knock|knock', token=csrf_token)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_unblock_user(self, raw_api, site_init):
+        csrf_token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'csrftoken': csrf_token}}
+                }
+            else:
+                return {
+                    'id': 1,
+                    'user': 'myusername',
+                    'userID': 1,
+                    'reason': kwargs['reason']
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        site.unblock_user(username='myusername', reason='Test')
+
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='csrf')
+        assert call_args[1] == mock.call('unblock', 'POST',
+                                         user='myusername', reason='Test',
+                                         token=csrf_token)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_block_user_fail_params(self, raw_api, site_init):
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(ValueError):
+            site.block_user(reason='Test', tags=['knock', 'knock'])
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_block_user_fail_unkown(self, raw_api, site_init):
+        csrf_token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'csrftoken': csrf_token}}
+                }
+            else:
+                return {
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(mwclient.errors.APIError):
+            site.block_user(userid=42, reason='Test',
+                            tags=['knock', 'knock'])
+
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='csrf')
+        assert call_args[1] == mock.call('block', 'POST',
+                                         userid=42, reason='Test',
+                                         tags='knock|knock', token=csrf_token)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user_groups(self, raw_api, site_init):
+        def side_effect(*args, **kwargs):
+            if kwargs.get('list') == 'users':
+                return {
+                    'query': {
+                        'users': [{
+                            'userid': 1,
+                            'user': 'myusername',
+                            'groups': ['*', 'user']
+                        }]
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        groups = site.get_user_groups(username='myusername')
+
+        call_kwargs = {
+            'ususers': 'myusername',
+            'list': 'users',
+            'continue': '',
+            'meta': 'userinfo',
+            'uiprop': 'blockinfo|hasmsg',
+            'usprop': 'groups'
+        }
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 1
+        assert call_args[0] == mock.call('query', 'GET', **call_kwargs)
+        assert groups == ['*', 'user']
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user_groups_fail_params(self, raw_api, site_init):
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(ValueError):
+            site.get_user_groups(username=None)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_get_user_groups_fail_notfound(self, raw_api, site_init):
+        def side_effect(*args, **kwargs):
+            if kwargs.get('list') == 'users':
+                return {
+                    'query': {
+                        'users': [{
+                            'userid': 42,
+                            'missing': ''
+                        }]
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(mwclient.errors.UserNotFound):
+            site.get_user_groups(userid=42)
+
+        call_kwargs = {
+            'ususerids': 42,
+            'list': 'users',
+            'continue': '',
+            'meta': 'userinfo',
+            'uiprop': 'blockinfo|hasmsg',
+            'usprop': 'groups'
+        }
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 1
+        assert call_args[0] == mock.call('query', 'GET', **call_kwargs)
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_add_user_groups(self, raw_api, site_init):
+        token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'userrightstoken': token}}
+                }
+            else:
+                return {
+                    'userrights': {
+                        'userid': 42,
+                        'user': 'myusername',
+                        'added': ['*', 'user', 'bureaucrat', 'sysop'],
+                        'removed': []
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        added = site.add_user_groups(username='myusername',
+                                     groups=['*', 'user', 'bureaucrat', 'sysop'])
+
+        mock_call = mock.call('userrights', 'POST', **{
+            'user': 'myusername',
+            'add': '*|user|bureaucrat|sysop',
+            'token': token,
+        })
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='userrights')
+
+        mock_call_kwargs = deepcopy(mock_call[2])
+        real_call_kwargs = deepcopy(call_args[1][1])
+        assert 'add' in real_call_kwargs
+        assert 'add' in mock_call_kwargs
+        add_kwargs = set(real_call_kwargs.pop('add').split('|'))
+        assert add_kwargs == set(mock_call_kwargs.pop('add').split('|'))
+
+        assert 'remove' not in real_call_kwargs
+        assert 'remove' not in mock_call_kwargs
+
+        assert real_call_kwargs == mock_call_kwargs
+        assert mock_call.args == call_args[1].args
+
+        assert added == ['*', 'user', 'bureaucrat', 'sysop']
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_add_user_groups_expirty_formats(self, raw_api, site_init):
+        token = 'abc+\\'
+        today = date.today()
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'userrightstoken': token}}
+                }
+            else:
+                return {
+                    'userrights': {
+                        'userid': 42,
+                        'user': 'myusername',
+                        'added': ['*', 'user', 'bureaucrat', 'sysop'],
+                        'removed': []
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        call_kwargs = {
+            'username': 'myusername',
+            'groups': ['*', 'user', 'bureaucrat', 'sysop'],
+        }
+        mock_call_kwargs = {
+            'user': 'myusername',
+            'add': '*|user|bureaucrat|sysop',
+            'token': token,
+        }
+
+        expirty_formats = (
+            (['2042-01-01', '2042-01-02'], '2042-01-01|2042-01-02'),
+            ('2042-01-01', '2042-01-01'),
+            (today, '%s' % today),
+        )
+
+        for call_fmt_expiry, mock_call_expiry in expirty_formats:
+            call_kwargs['expiry'] = call_fmt_expiry
+            mock_call_kwargs['expiry'] = mock_call_expiry
+            added = site.add_user_groups(**call_kwargs)
+            assert added == ['*', 'user', 'bureaucrat', 'sysop']
+            real_mock_call_kwargs = mock.call('userrights', 'POST', **mock_call_kwargs)[2]
+            real_call_kwargs = raw_api.call_args_list[-1][1]
+            assert real_mock_call_kwargs['expiry'] == real_call_kwargs['expiry']
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_add_user_groups_fail_notfound(self, raw_api, site_init):
+        token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'userrightstoken': token}}
+                }
+            else:
+                raise mwclient.errors.APIError('nosuchuser', 'Blah', kwargs)
+
+        raw_api.side_effect = side_effect
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(mwclient.errors.UserNotFound):
+            site.add_user_groups(username='notfound',
+                                 groups=['*', 'user', 'bureaucrat', 'sysop'])
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_add_user_groups_fail_params(self, raw_api, site_init):
+        site = mwclient.Site('test.wikipedia.org')
+        with pytest.raises(ValueError):
+            site.add_user_groups(username=None,
+                                 groups=['*', 'user', 'bureaucrat', 'sysop'])
+        assert [] == site.add_user_groups(userid=42, groups=[])
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_remove_user_groups(self, raw_api, site_init):
+        token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'userrightstoken': token}}
+                }
+            else:
+                return {
+                    'userrights': {
+                        'userid': 42,
+                        'user': 'myusername',
+                        'removed': ['bureaucrat', 'sysop'],
+                        'added': []
+                    }
+                }
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        removed = site.remove_user_groups(username='myusername',
+                                          groups=['bureaucrat', 'sysop'],
+                                          reason='Test')
+
+        mock_call = mock.call('userrights', 'POST', **{
+            'user': 'myusername',
+            'remove': 'bureaucrat|sysop',
+            'token': token,
+            'reason': 'Test',
+        })
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 2
+        assert call_args[0] == mock.call('query', 'GET', meta='tokens', type='userrights')
+
+        mock_call_kwargs = deepcopy(mock_call[2])
+        real_call_kwargs = deepcopy(call_args[1][1])
+        assert 'remove' in real_call_kwargs
+        assert 'remove' in mock_call_kwargs
+        remove_kwargs = set(real_call_kwargs.pop('remove').split('|'))
+        assert remove_kwargs == set(mock_call_kwargs.pop('remove').split('|'))
+
+        assert 'add' not in real_call_kwargs
+        assert 'add' not in mock_call_kwargs
+
+        assert real_call_kwargs == mock_call_kwargs
+        assert mock_call.args == call_args[1].args
+
+        assert removed == ['bureaucrat', 'sysop']
+
+    @mock.patch('mwclient.client.Site.site_init')
+    @mock.patch('mwclient.client.Site.raw_api')
+    def test_set_user_groups(self, raw_api, site_init):
+        token = 'abc+\\'
+
+        def side_effect(*args, **kwargs):
+            if kwargs.get('meta') == 'tokens':
+                return {
+                    'query': {'tokens': {'userrightstoken': token}}
+                }
+            elif kwargs.get('list') == 'users':
+                return {
+                    'query': {
+                        'users': [{
+                            'userid': 42,
+                            'user': 'myusername',
+                            'groups': ['*', 'user', 'bot', 'interface-admin']
+                        }]
+                    }
+                }
+            else:
+                return {}
+
+        raw_api.side_effect = side_effect
+
+        site = mwclient.Site('test.wikipedia.org')
+        site.set_user_groups(userid=42,
+                             tags=['one', 'two'],
+                             reason='Test',
+                             groups=['*', 'user', 'bureaucrat', 'sysop'])
+
+        get_groups_call_kwargs = {
+            'ususerids': 42,
+            'list': 'users',
+            'continue': '',
+            'meta': 'userinfo',
+            'uiprop': 'blockinfo|hasmsg',
+            'usprop': 'groups'
+        }
+        set_groups_call_kwargs = {
+            'userid': 42,
+            'remove': 'bot|interface-admin',
+            'add': 'bureaucrat|sysop',
+            'reason': 'Test',
+            'tags': 'one|two',
+            'token': token,
+        }
+        call_args = raw_api.call_args_list
+        assert len(call_args) == 3
+        assert call_args[0] == mock.call('query', 'GET', **get_groups_call_kwargs)
+        assert call_args[1] == mock.call('query', 'GET', meta='tokens', type='userrights')
+
+        mock_call = mock.call('userrights', 'POST', **set_groups_call_kwargs)
+        mock_call_kwargs = deepcopy(mock_call[2])
+        real_call_kwargs = deepcopy(call_args[2][1])
+        assert 'add' in real_call_kwargs
+        assert 'add' in mock_call_kwargs
+        add_kwargs = set(real_call_kwargs.pop('add').split('|'))
+        assert add_kwargs == set(mock_call_kwargs.pop('add').split('|'))
+
+        assert 'remove' in real_call_kwargs
+        assert 'remove' in mock_call_kwargs
+        remove_kwargs = set(real_call_kwargs.pop('remove').split('|'))
+        assert remove_kwargs == set(mock_call_kwargs.pop('remove').split('|'))
+
+        assert real_call_kwargs == mock_call_kwargs
+        assert mock_call.args == call_args[2].args
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab