Skip to content

Commit

Permalink
Enhance support of tls_requires in mysql_user and mysql_info (#628)
Browse files Browse the repository at this point in the history
* fix option name

* Add tests for users using SSL

* Rewrite get_tls_requires using mysql.user table

* Add tls_requires to users_info filter

* add more consistant test users

* Add tls tests users in cleanup task

* Fix tls_requires data structure inconsistencies between modules

* Refactor user implementation to host get_tls_requires

* fix MySQL tls_requires not removed from user passed as empty

* Fix wrong variable used to return a hashed password

* Fix sanity

* fix unit tests

* Add changelog fragment

* Add PR URI to the changelog

* Add more precise change log

* fix documentation using wrong variable as an example

* Document example returned value `tls_requires` from users_info filter

* Revert changes that will be in a separate PR

* Fix sanity
  • Loading branch information
laurent-indermuehle authored Apr 16, 2024
1 parent 0618ff6 commit 47710cf
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 55 deletions.
6 changes: 6 additions & 0 deletions changelogs/fragments/mysql_user_tls_requires.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
minor_changes:
- mysql_info - Add ``tls_requires`` returned value for the ``users_info`` filter (https://github.com/ansible-collections/community.mysql/pull/628).
bugfixes:
- mysql_user - Fix idempotence when using variables from the ``users_info`` filter of ``mysql_info`` as an input (https://github.com/ansible-collections/community.mysql/pull/628).
- mysql_user - Fix ``tls_requires`` not removing ``SSL`` and ``X509`` when sets as empty (https://github.com/ansible-collections/community.mysql/pull/628).
45 changes: 45 additions & 0 deletions plugins/module_utils/implementations/mariadb/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,48 @@ def server_supports_password_expire(cursor):
version = get_server_version(cursor)

return LooseVersion(version) >= LooseVersion("10.4.3")


def get_tls_requires(cursor, user, host):
"""Get user TLS requirements.
Reads directly from mysql.user table allowing for a more
readable code.
Args:
cursor (cursor): DB driver cursor object.
user (str): User name.
host (str): User host name.
Returns: Dictionary containing current TLS required
"""
tls_requires = dict()

query = ('SELECT ssl_type, ssl_cipher, x509_issuer, x509_subject '
'FROM mysql.user WHERE User = %s AND Host = %s')
cursor.execute(query, (user, host))
res = cursor.fetchone()

# Mysql_info use a DictCursor so we must convert back to a list
# otherwise we get KeyError 0
if isinstance(res, dict):
res = list(res.values())

# When user don't require SSL, res value is: ('', '', '', '')
if not any(res):
return None

if res[0] == 'ANY':
tls_requires['SSL'] = None

if res[0] == 'X509':
tls_requires['X509'] = None

if res[1]:
tls_requires['CIPHER'] = res[1]

if res[2]:
tls_requires['ISSUER'] = res[2]

if res[3]:
tls_requires['SUBJECT'] = res[3]
return tls_requires
46 changes: 46 additions & 0 deletions plugins/module_utils/implementations/mysql/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from ansible_collections.community.mysql.plugins.module_utils.version import LooseVersion
from ansible_collections.community.mysql.plugins.module_utils.mysql import get_server_version

import re
import shlex


def use_old_user_mgmt(cursor):
version = get_server_version(cursor)
Expand All @@ -30,3 +33,46 @@ def server_supports_password_expire(cursor):
version = get_server_version(cursor)

return LooseVersion(version) >= LooseVersion("5.7")


def get_tls_requires(cursor, user, host):
"""Get user TLS requirements.
We must use SHOW GRANTS because some tls fileds are encoded.
Args:
cursor (cursor): DB driver cursor object.
user (str): User name.
host (str): User host name.
Returns: Dictionary containing current TLS required
"""
if not use_old_user_mgmt(cursor):
query = "SHOW CREATE USER '%s'@'%s'" % (user, host)
else:
query = "SHOW GRANTS for '%s'@'%s'" % (user, host)

cursor.execute(query)
grants = cursor.fetchone()

# Mysql_info use a DictCursor so we must convert back to a list
# otherwise we get KeyError 0
if isinstance(grants, dict):
grants = list(grants.values())
grants_str = ''.join(grants)

pattern = r"(?<=\bREQUIRE\b)(.*?)(?=(?:\bPASSWORD\b|$))"
requires_match = re.search(pattern, grants_str)
requires = requires_match.group().strip() if requires_match else ""

if requires.startswith('NONE'):
return None

if requires.startswith('SSL'):
return {'SSL': None}

if requires.startswith('X509'):
return {'X509': None}

items = iter(shlex.split(requires))
requires = dict(zip(items, items))
return requires or None
43 changes: 11 additions & 32 deletions plugins/module_utils/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ansible_collections.community.mysql.plugins.module_utils.mysql import (
mysql_driver,
get_server_implementation,
)


Expand Down Expand Up @@ -80,31 +81,6 @@ def do_not_mogrify_requires(query, params, tls_requires):
return query, params


def get_tls_requires(cursor, user, host):
if user:
if not impl.use_old_user_mgmt(cursor):
query = "SHOW CREATE USER '%s'@'%s'" % (user, host)
else:
query = "SHOW GRANTS for '%s'@'%s'" % (user, host)

cursor.execute(query)
require_list = [tuple[0] for tuple in filter(lambda x: "REQUIRE" in x[0], cursor.fetchall())]
require_line = require_list[0] if require_list else ""
pattern = r"(?<=\bREQUIRE\b)(.*?)(?=(?:\bPASSWORD\b|$))"
requires_match = re.search(pattern, require_line)
requires = requires_match.group().strip() if requires_match else ""
if any((requires.startswith(req) for req in ('SSL', 'X509', 'NONE'))):
requires = requires.split()[0]
if requires == 'NONE':
requires = None
else:
import shlex

items = iter(shlex.split(requires))
requires = dict(zip(items, items))
return requires or None


def get_grants(cursor, user, host):
cursor.execute("SHOW GRANTS FOR %s@%s", (user, host))
grants_line = list(filter(lambda x: "ON *.*" in x[0], cursor.fetchall()))[0]
Expand Down Expand Up @@ -166,6 +142,7 @@ def user_add(cursor, user, host, host_all, password, encrypted,
return {'changed': True, 'password_changed': None, 'attributes': attributes}

# Determine what user management method server uses
impl = get_user_implementation(cursor)
old_user_mgmt = impl.use_old_user_mgmt(cursor)

mogrify = do_not_mogrify_requires if old_user_mgmt else mogrify_requires
Expand Down Expand Up @@ -244,6 +221,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted,
grant_option = False

# Determine what user management method server uses
impl = get_user_implementation(cursor)
old_user_mgmt = impl.use_old_user_mgmt(cursor)

if host_all and not role:
Expand Down Expand Up @@ -499,7 +477,7 @@ def user_mod(cursor, user, host, host_all, password, encrypted,
continue

# Handle TLS requirements
current_requires = get_tls_requires(cursor, user, host)
current_requires = sanitize_requires(impl.get_tls_requires(cursor, user, host))
if current_requires != tls_requires:
msg = "TLS requires updated"
if not module.check_mode:
Expand Down Expand Up @@ -837,6 +815,7 @@ def privileges_grant(cursor, user, host, db_table, priv, tls_requires, maria_rol
query.append("TO %s")
params = (user)

impl = get_user_implementation(cursor)
if tls_requires and impl.use_old_user_mgmt(cursor):
query, params = mogrify_requires(" ".join(query), params, tls_requires)
query = [query]
Expand Down Expand Up @@ -973,6 +952,7 @@ def limit_resources(module, cursor, user, host, resource_limits, check_mode):
Returns: True, if changed, False otherwise.
"""
impl = get_user_implementation(cursor)
if not impl.server_supports_alter_user(cursor):
module.fail_json(msg="The server version does not match the requirements "
"for resource_limits parameter. See module's documentation.")
Expand Down Expand Up @@ -1108,12 +1088,11 @@ def attributes_get(cursor, user, host):
return j if j else None


def get_impl(cursor):
global impl
cursor.execute("SELECT VERSION()")
if 'mariadb' in cursor.fetchone()[0].lower():
def get_user_implementation(cursor):
db_engine = get_server_implementation(cursor)
if db_engine == 'mariadb':
from ansible_collections.community.mysql.plugins.module_utils.implementations.mariadb import user as mariauser
impl = mariauser
return mariauser
else:
from ansible_collections.community.mysql.plugins.module_utils.implementations.mysql import user as mysqluser
impl = mysqluser
return mysqluser
22 changes: 17 additions & 5 deletions plugins/modules/mysql_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
plugin: "{{ item.plugin | default(omit) }}"
plugin_auth_string: "{{ item.plugin_auth_string | default(omit) }}"
plugin_hash_string: "{{ item.plugin_hash_string | default(omit) }}"
tls_require: "{{ item.tls_require | default(omit) }}"
tls_requires: "{{ item.tls_requires | default(omit) }}"
priv: "{{ item.priv | default(omit) }}"
resource_limits: "{{ item.resource_limits | default(omit) }}"
column_case_sensitive: true
Expand Down Expand Up @@ -240,7 +240,8 @@
"host": "host.com",
"plugin": "mysql_native_password",
"priv": "db1.*:SELECT/db2.*:SELECT",
"resource_limits": { "MAX_USER_CONNECTIONS": 100 } }
"resource_limits": { "MAX_USER_CONNECTIONS": 100 },
"tls_requires": { "SSL": null } }
version_added: '3.8.0'
engines:
description: Information about the server's storage engines.
Expand Down Expand Up @@ -300,6 +301,7 @@
privileges_get,
get_resource_limits,
get_existing_authentication,
get_user_implementation,
)
from ansible.module_utils.six import iteritems
from ansible.module_utils._text import to_native
Expand Down Expand Up @@ -327,10 +329,11 @@ class MySQL_Info(object):
5. add info about the new subset with an example to RETURN block
"""

def __init__(self, module, cursor, server_implementation):
def __init__(self, module, cursor, server_implementation, user_implementation):
self.module = module
self.cursor = cursor
self.server_implementation = server_implementation
self.user_implementation = user_implementation
self.info = {
'version': {},
'databases': {},
Expand Down Expand Up @@ -602,13 +605,17 @@ def __get_users_info(self):
priv_string.remove('*.*:USAGE')

resource_limits = get_resource_limits(self.cursor, user, host)

copy_ressource_limits = dict.copy(resource_limits)

tls_requires = self.user_implementation.get_tls_requires(
self.cursor, user, host)

output_dict = {
'name': user,
'host': host,
'priv': '/'.join(priv_string),
'resource_limits': copy_ressource_limits,
'tls_requires': tls_requires,
}

# Prevent returning a resource limit if empty
Expand All @@ -619,6 +626,10 @@ def __get_users_info(self):
if len(output_dict['resource_limits']) == 0:
del output_dict['resource_limits']

# Prevent returning tls_require if empty
if not tls_requires:
del output_dict['tls_requires']

authentications = get_existing_authentication(self.cursor, user, host)
if authentications:
output_dict.update(authentications)
Expand Down Expand Up @@ -745,11 +756,12 @@ def main():
module.fail_json(msg)

server_implementation = get_server_implementation(cursor)
user_implementation = get_user_implementation(cursor)

###############################
# Create object and do main job

mysql = MySQL_Info(module, cursor, server_implementation)
mysql = MySQL_Info(module, cursor, server_implementation, user_implementation)

module.exit_json(changed=False,
connector_name=connector_name,
Expand Down
4 changes: 2 additions & 2 deletions plugins/modules/mysql_role.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@
)
from ansible_collections.community.mysql.plugins.module_utils.user import (
convert_priv_dict_to_str,
get_impl,
get_user_implementation,
get_mode,
user_mod,
privileges_grant,
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def main():
# Set defaults
changed = False

get_impl(cursor)
impl = get_user_implementation(cursor)

if priv is not None:
try:
Expand Down
3 changes: 0 additions & 3 deletions plugins/modules/mysql_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@
)
from ansible_collections.community.mysql.plugins.module_utils.user import (
convert_priv_dict_to_str,
get_impl,
get_mode,
InvalidPrivsError,
limit_resources,
Expand Down Expand Up @@ -528,8 +527,6 @@ def main():
if session_vars:
set_session_vars(module, cursor, session_vars)

get_impl(cursor)

if priv is not None:
try:
mode = get_mode(cursor)
Expand Down
Loading

0 comments on commit 47710cf

Please sign in to comment.