Skip to content

Commit

Permalink
Refactor 'Document.get' to use the 'GetDocument' API. (#6534)
Browse files Browse the repository at this point in the history
Update conformance test to actually run for 'get'.

Toward #6533.
  • Loading branch information
tseaver authored Nov 27, 2018
1 parent 70ab243 commit 5420fea
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 74 deletions.
36 changes: 33 additions & 3 deletions firestore/google/cloud/firestore_v1beta1/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import six

from google.api_core import exceptions
from google.cloud.firestore_v1beta1 import _helpers
from google.cloud.firestore_v1beta1.proto import common_pb2
from google.cloud.firestore_v1beta1.watch import Watch


Expand Down Expand Up @@ -423,9 +425,37 @@ def get(self, field_paths=None, transaction=None):
if isinstance(field_paths, six.string_types):
raise ValueError(
"'field_paths' must be a sequence of paths, not a string.")
snapshot_generator = self._client.get_all(
[self], field_paths=field_paths, transaction=transaction)
return _consume_single_get(snapshot_generator)

if field_paths is not None:
mask = common_pb2.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None

firestore_api = self._client._firestore_api
try:
document_pb = firestore_api.get_document(
self._document_path,
mask=mask,
transaction=_helpers.get_transaction_id(transaction),
metadata=self._client._rpc_metadata)
except exceptions.NotFound:
data = None
exists = False
create_time = None
update_time = None
else:
data = _helpers.decode_dict(document_pb.fields, self._client)
exists = True
create_time = document_pb.create_time
update_time = document_pb.update_time

return DocumentSnapshot(
reference=self,
data=data,
exists=exists,
read_time=None, # No server read_time available
create_time=create_time,
update_time=update_time)

def collections(self, page_size=None):
"""List subcollections of the current document.
Expand Down
1 change: 0 additions & 1 deletion firestore/tests/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,6 @@ def test_document_get(client, cleanup):
write_result = document.create(data)
snapshot = document.get()
check_snapshot(snapshot, document, data, write_result)
assert_timestamp_less(snapshot.create_time, snapshot.read_time)


def test_document_delete(client, cleanup):
Expand Down
26 changes: 13 additions & 13 deletions firestore/tests/unit/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from google.protobuf import text_format
from google.cloud.firestore_v1beta1.proto import document_pb2
from google.cloud.firestore_v1beta1.proto import firestore_pb2
from google.cloud.firestore_v1beta1.proto import test_pb2
from google.cloud.firestore_v1beta1.proto import write_pb2
Expand Down Expand Up @@ -170,19 +171,18 @@ def test_create_testprotos(test_proto):
@pytest.mark.parametrize('test_proto', _GET_TESTPROTOS)
def test_get_testprotos(test_proto):
testcase = test_proto.get
# XXX this stub currently does nothing because no get testcases have
# is_error; taking this bit out causes the existing tests to fail
# due to a lack of batch getting
try:
testcase.is_error
except AttributeError:
return
else: # pragma: NO COVER
testcase = test_proto.get
firestore_api = _mock_firestore_api()
client, document = _make_client_document(firestore_api, testcase)
call = functools.partial(document.get, None, None)
_run_testcase(testcase, call, firestore_api, client)
firestore_api = mock.Mock(spec=['get_document'])
response = document_pb2.Document()
firestore_api.get_document.return_value = response
client, document = _make_client_document(firestore_api, testcase)

document.get() # No '.textprotos' for errors, field_paths.

firestore_api.get_document.assert_called_once_with(
document._document_path,
mask=None,
transaction=None,
metadata=client._rpc_metadata)


@pytest.mark.parametrize('test_proto', _SET_TESTPROTOS)
Expand Down
130 changes: 73 additions & 57 deletions firestore/tests/unit/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,74 +463,90 @@ def test_delete_with_option(self):
)
self._delete_helper(last_update_time=timestamp_pb)

def test_get_w_single_field_path(self):
client = mock.Mock(spec=[])
def _get_helper(
self, field_paths=None, use_transaction=False, not_found=False):
from google.api_core.exceptions import NotFound
from google.cloud.firestore_v1beta1.proto import common_pb2
from google.cloud.firestore_v1beta1.proto import document_pb2
from google.cloud.firestore_v1beta1.transaction import Transaction

document = self._make_one('yellow', 'mellow', client=client)
with self.assertRaises(ValueError):
document.get('foo')
# Create a minimal fake GAPIC with a dummy response.
create_time = 123
update_time = 234
firestore_api = mock.Mock(spec=['get_document'])
response = mock.create_autospec(document_pb2.Document)
response.fields = {}
response.create_time = create_time
response.update_time = update_time

if not_found:
firestore_api.get_document.side_effect = NotFound('testing')
else:
firestore_api.get_document.return_value = response

def test_get_success(self):
# Create a minimal fake client with a dummy response.
response_iterator = iter([mock.sentinel.snapshot])
client = mock.Mock(spec=['get_all'])
client.get_all.return_value = response_iterator
client = _make_client('donut-base')
client._firestore_api_internal = firestore_api

# Actually make a document and call get().
document = self._make_one('yellow', 'mellow', client=client)
snapshot = document.get()
document = self._make_one('where', 'we-are', client=client)

# Verify the response and the mocks.
self.assertIs(snapshot, mock.sentinel.snapshot)
client.get_all.assert_called_once_with(
[document], field_paths=None, transaction=None)
if use_transaction:
transaction = Transaction(client)
transaction_id = transaction._id = b'asking-me-2'
else:
transaction = None

snapshot = document.get(
field_paths=field_paths, transaction=transaction)

self.assertIs(snapshot.reference, document)
if not_found:
self.assertIsNone(snapshot._data)
self.assertFalse(snapshot.exists)
self.assertIsNone(snapshot.read_time)
self.assertIsNone(snapshot.create_time)
self.assertIsNone(snapshot.update_time)
else:
self.assertEqual(snapshot.to_dict(), {})
self.assertTrue(snapshot.exists)
self.assertIsNone(snapshot.read_time)
self.assertIs(snapshot.create_time, create_time)
self.assertIs(snapshot.update_time, update_time)

# Verify the request made to the API
if field_paths is not None:
mask = common_pb2.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None

def test_get_with_transaction(self):
from google.cloud.firestore_v1beta1.client import Client
from google.cloud.firestore_v1beta1.transaction import Transaction
if use_transaction:
expected_transaction_id = transaction_id
else:
expected_transaction_id = None

# Create a minimal fake client with a dummy response.
response_iterator = iter([mock.sentinel.snapshot])
client = mock.create_autospec(Client, instance=True)
client.get_all.return_value = response_iterator
firestore_api.get_document.assert_called_once_with(
document._document_path,
mask=mask,
transaction=expected_transaction_id,
metadata=client._rpc_metadata)

# Actually make a document and call get().
document = self._make_one('yellow', 'mellow', client=client)
transaction = Transaction(client)
transaction._id = b'asking-me-2'
snapshot = document.get(transaction=transaction)
def test_get_not_found(self):
self._get_helper(not_found=True)

# Verify the response and the mocks.
self.assertIs(snapshot, mock.sentinel.snapshot)
client.get_all.assert_called_once_with(
[document], field_paths=None, transaction=transaction)
def test_get_default(self):
self._get_helper()

def test_get_not_found(self):
from google.cloud.firestore_v1beta1.document import DocumentSnapshot
def test_get_w_string_field_path(self):
with self.assertRaises(ValueError):
self._get_helper(field_paths='foo')

# Create a minimal fake client with a dummy response.
read_time = 123
expected = DocumentSnapshot(None, None, False, read_time, None, None)
response_iterator = iter([expected])
client = mock.Mock(
_database_string='sprinklez',
spec=['_database_string', 'get_all'])
client.get_all.return_value = response_iterator

# Actually make a document and call get().
document = self._make_one('house', 'cowse', client=client)
field_paths = ['x.y', 'x.z', 't']
snapshot = document.get(field_paths=field_paths)
self.assertIsNone(snapshot.reference)
self.assertIsNone(snapshot._data)
self.assertFalse(snapshot.exists)
self.assertEqual(snapshot.read_time, expected.read_time)
self.assertIsNone(snapshot.create_time)
self.assertIsNone(snapshot.update_time)
def test_get_with_field_path(self):
self._get_helper(field_paths=['foo'])

# Verify the response and the mocks.
client.get_all.assert_called_once_with(
[document], field_paths=field_paths, transaction=None)
def test_get_with_multiple_field_paths(self):
self._get_helper(field_paths=['foo', 'bar.baz'])

def test_get_with_transaction(self):
self._get_helper(use_transaction=True)

def _collections_helper(self, page_size=None):
from google.api_core.page_iterator import Iterator
Expand Down

0 comments on commit 5420fea

Please sign in to comment.