Skip to content

Commit

Permalink
client: allow for custom kafka clients
Browse files Browse the repository at this point in the history
Provide the consumer, producer and admin client with the option to
create the kafka client from a custom callable, thus allowing more
flexibility in handling certain low level errors
  • Loading branch information
Gabriel Tincu committed Nov 19, 2020
1 parent 6f932ba commit 274b169
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
10 changes: 7 additions & 3 deletions kafka/admin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class KafkaAdminClient(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client (callable): Custom class / callable for creating KafkaClient instances
"""
DEFAULT_CONFIG = {
Expand Down Expand Up @@ -186,6 +187,7 @@ class KafkaAdminClient(object):
'metric_reporters': [],
'metrics_num_samples': 2,
'metrics_sample_window_ms': 30000,
'client': KafkaClient,
}

def __init__(self, **configs):
Expand All @@ -205,9 +207,11 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

self._client = KafkaClient(metrics=self._metrics,
metric_group_prefix='admin',
**self.config)
self._client = self.config['client'](
metrics=self._metrics,
metric_group_prefix='admin',
**self.config
)
self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000))

# Get auto-discovered version from client if necessary
Expand Down
4 changes: 3 additions & 1 deletion kafka/consumer/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class KafkaConsumer(six.Iterator):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -306,6 +307,7 @@ class KafkaConsumer(six.Iterator):
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None,
'legacy_iterator': False, # enable to revert to < 1.4.7 iterator
'client': KafkaClient,
}
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000

Expand Down Expand Up @@ -353,7 +355,7 @@ def __init__(self, *topics, **configs):
log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated',
str(self.config['api_version']), str_version)

self._client = KafkaClient(metrics=self._metrics, **self.config)
self._client = self.config['client'](metrics=self._metrics, **self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down
11 changes: 7 additions & 4 deletions kafka/producer/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class KafkaProducer(object):
sasl mechanism handshake. Default: one of bootstrap servers
sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider
instance. (See kafka.oauth.abstract). Default: None
client (callable): Custom class / callable for creating KafkaClient instances
Note:
Configuration parameters are described in more detail at
Expand Down Expand Up @@ -332,7 +333,8 @@ class KafkaProducer(object):
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None
'sasl_oauth_token_provider': None,
'client': KafkaClient,
}

_COMPRESSORS = {
Expand Down Expand Up @@ -378,9 +380,10 @@ def __init__(self, **configs):
reporters = [reporter() for reporter in self.config['metric_reporters']]
self._metrics = Metrics(metric_config, reporters)

client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)
client = self.config['client'](
metrics=self._metrics, metric_group_prefix='producer',
wakeup_timeout_ms=self.config['max_block_ms'],
**self.config)

# Get auto-discovered version from client if necessary
if self.config['api_version'] is None:
Expand Down

0 comments on commit 274b169

Please sign in to comment.