Skip to content

Commit

Permalink
Check _can_send_recv with lock to verify state
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp committed Mar 31, 2019
1 parent 47510f5 commit c1c71f7
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,8 @@ def _try_authenticate_plain(self, future):
size = Int32.encode(len(msg))
try:
with self._lock:
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
Expand Down Expand Up @@ -616,6 +618,8 @@ def _try_authenticate_gssapi(self, future):
log.debug('%s: GSSAPI name: %s', self, gssapi_name)

self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
Expand Down Expand Up @@ -677,6 +681,8 @@ def _try_authenticate_oauth(self, future):
msg = bytes(self._build_oauth_client_request().encode("utf-8"))
size = Int32.encode(len(msg))
self._lock.acquire()
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)
Expand Down Expand Up @@ -816,6 +822,11 @@ def close(self, error=None):
for (_correlation_id, (future, _timestamp)) in ifrs:
future.failure(error)

def _can_send_recv(self):
"""Return True iff socket is ready for requests / responses"""
return self.state in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED)

def send(self, request, blocking=True):
"""Queue request for async network send, return Future()"""
future = Future()
Expand All @@ -830,8 +841,7 @@ def send(self, request, blocking=True):
def _send(self, request, blocking=True):
future = Future()
with self._lock:
if self.state not in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED):
if not self._can_send_recv():
return future.failure(Errors.NodeNotReadyError(str(self)))

correlation_id = self._protocol.send_request(request)
Expand All @@ -855,8 +865,7 @@ def send_pending_requests(self):
"""Can block on network if request is larger than send_buffer_bytes"""
try:
with self._lock:
if self.state not in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED):
if not self._can_send_recv():
return Errors.NodeNotReadyError(str(self))
# In the future we might manage an internal write buffer
# and send bytes asynchronously. For now, just block
Expand All @@ -882,19 +891,6 @@ def recv(self):
Return list of (response, future) tuples
"""
if self.state not in (ConnectionStates.AUTHENTICATING,
ConnectionStates.CONNECTED):
log.warning('%s cannot recv: socket not connected', self)
# If requests are pending, we should close the socket and
# fail all the pending request futures
if self.in_flight_requests:
self.close(Errors.KafkaConnectionError('Socket not connected during recv with in-flight-requests'))
return ()

elif not self.in_flight_requests:
log.warning('%s: No in-flight-requests to recv', self)
return ()

responses = self._recv()
if not responses and self.requests_timed_out():
log.warning('%s timed out after %s ms. Closing connection.',
Expand Down Expand Up @@ -925,6 +921,11 @@ def _recv(self):
"""Take all available bytes from socket, return list of any responses from parser"""
recvd = []
self._lock.acquire()
if not self._can_send_recv():
log.warning('%s cannot recv: socket not connected', self)
self._lock.release()
return ()

while len(recvd) < self.config['sock_chunk_buffer_count']:
try:
data = self._sock.recv(self.config['sock_chunk_bytes'])
Expand Down

0 comments on commit c1c71f7

Please sign in to comment.