Skip to content

Commit

Permalink
Refactor 'Document.get' to use the 'GetDocument' API.
Browse files Browse the repository at this point in the history
Update conformance test to actually run for 'get'.

Toward #6533.
  • Loading branch information
tseaver committed Nov 15, 2018
1 parent 88a0ac5 commit ade1acf
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 73 deletions.
37 changes: 34 additions & 3 deletions firestore/google/cloud/firestore_v1beta1/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import six

from google.api_core import datetime_helpers
from google.api_core import exceptions
from google.cloud.firestore_v1beta1 import _helpers
from google.cloud.firestore_v1beta1.proto import common_pb2
from google.cloud.firestore_v1beta1.watch import Watch


Expand Down Expand Up @@ -426,9 +428,38 @@ def get(self, field_paths=None, transaction=None):
if isinstance(field_paths, six.string_types):
raise ValueError(
"'field_paths' must be a sequence of paths, not a string.")
snapshot_generator = self._client.get_all(
[self], field_paths=field_paths, transaction=transaction)
return _consume_single_get(snapshot_generator)

if field_paths is not None:
mask = common_pb2.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None

firestore_api = self._client._firestore_api
timestamp = _make_timestamp()
try:
document_pb = firestore_api.get_document(
self._document_path,
mask=mask,
transaction=_helpers.get_transaction_id(transaction),
metadata=self._client._rpc_metadata)
except exceptions.NotFound:
data = None
exists = False
create_time = None
update_time = None
else:
data = _helpers.decode_dict(document_pb.fields, self._client)
exists = True
create_time = document_pb.create_time
update_time = document_pb.update_time

return DocumentSnapshot(
reference=self,
data=data,
exists=exists,
read_time=timestamp, # No server read_time available
create_time=create_time,
update_time=update_time)

def collections(self, page_size=None):
"""List subcollections of the current document.
Expand Down
26 changes: 13 additions & 13 deletions firestore/tests/unit/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from google.protobuf import text_format
from google.cloud.firestore_v1beta1.proto import document_pb2
from google.cloud.firestore_v1beta1.proto import firestore_pb2
from google.cloud.firestore_v1beta1.proto import test_pb2
from google.cloud.firestore_v1beta1.proto import write_pb2
Expand Down Expand Up @@ -180,19 +181,18 @@ def test_create_testprotos(test_proto):
@pytest.mark.parametrize('test_proto', _GET_TESTPROTOS)
def test_get_testprotos(test_proto):
testcase = test_proto.get
# XXX this stub currently does nothing because no get testcases have
# is_error; taking this bit out causes the existing tests to fail
# due to a lack of batch getting
try:
testcase.is_error
except AttributeError:
return
else: # pragma: NO COVER
testcase = test_proto.get
firestore_api = _mock_firestore_api()
client, document = _make_client_document(firestore_api, testcase)
call = functools.partial(document.get, None, None)
_run_testcase(testcase, call, firestore_api, client)
firestore_api = mock.Mock(spec=['get_document'])
response = document_pb2.Document()
firestore_api.get_document.return_value = response
client, document = _make_client_document(firestore_api, testcase)

document.get() # No '.textprotos' for errors, field_paths.

firestore_api.get_document.assert_called_once_with(
document._document_path,
mask=None,
transaction=None,
metadata=client._rpc_metadata)


@pytest.mark.parametrize('test_proto', _SET_TESTPROTOS)
Expand Down
133 changes: 76 additions & 57 deletions firestore/tests/unit/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,74 +459,93 @@ def test_delete_with_option(self):
)
self._delete_helper(last_update_time=timestamp_pb)

def test_get_w_single_field_path(self):
client = mock.Mock(spec=[])

document = self._make_one('yellow', 'mellow', client=client)
with self.assertRaises(ValueError):
document.get('foo')
def _get_helper(
self, field_paths=None, use_transaction=False, not_found=False):
from google.api_core.exceptions import NotFound
from google.cloud.firestore_v1beta1.proto import common_pb2
from google.cloud.firestore_v1beta1.proto import document_pb2
from google.cloud.firestore_v1beta1.transaction import Transaction

def test_get_success(self):
# Create a minimal fake client with a dummy response.
response_iterator = iter([mock.sentinel.snapshot])
client = mock.Mock(spec=['get_all'])
client.get_all.return_value = response_iterator
# Create a minimal fake GAPIC with a dummy response.
create_time = 123
update_time = 234
firestore_api = mock.Mock(spec=['get_document'])
response = mock.create_autospec(document_pb2.Document)
response.fields = {}
response.create_time = create_time
response.update_time = update_time

if not_found:
firestore_api.get_document.side_effect = NotFound('testing')
else:
firestore_api.get_document.return_value = response

# Actually make a document and call get().
document = self._make_one('yellow', 'mellow', client=client)
snapshot = document.get()
client = _make_client('donut-base')
client._firestore_api_internal = firestore_api

# Verify the response and the mocks.
self.assertIs(snapshot, mock.sentinel.snapshot)
client.get_all.assert_called_once_with(
[document], field_paths=None, transaction=None)
document = self._make_one('where', 'we-are', client=client)

def test_get_with_transaction(self):
from google.cloud.firestore_v1beta1.client import Client
from google.cloud.firestore_v1beta1.transaction import Transaction
if use_transaction:
transaction = Transaction(client)
transaction_id = transaction._id = b'asking-me-2'
else:
transaction = None

# Create a minimal fake client with a dummy response.
response_iterator = iter([mock.sentinel.snapshot])
client = mock.create_autospec(Client, instance=True)
client.get_all.return_value = response_iterator
tmts = Test__make_timestamp
patch, expected_stamp = tmts._make_datetime_module_patch()
with patch:
snapshot = document.get(
field_paths=field_paths, transaction=transaction)

self.assertIs(snapshot.reference, document)
if not_found:
self.assertIsNone(snapshot._data)
self.assertFalse(snapshot.exists)
self.assertEqual(snapshot.read_time, expected_stamp)
self.assertIsNone(snapshot.create_time)
self.assertIsNone(snapshot.update_time)
else:
self.assertEqual(snapshot.to_dict(), {})
self.assertTrue(snapshot.exists)
self.assertEqual(snapshot.read_time, expected_stamp)
self.assertIs(snapshot.create_time, create_time)
self.assertIs(snapshot.update_time, update_time)

# Verify the request made to the API
if field_paths is not None:
mask = common_pb2.DocumentMask(field_paths=sorted(field_paths))
else:
mask = None

# Actually make a document and call get().
document = self._make_one('yellow', 'mellow', client=client)
transaction = Transaction(client)
transaction._id = b'asking-me-2'
snapshot = document.get(transaction=transaction)
if use_transaction:
expected_transaction_id = transaction_id
else:
expected_transaction_id = None

# Verify the response and the mocks.
self.assertIs(snapshot, mock.sentinel.snapshot)
client.get_all.assert_called_once_with(
[document], field_paths=None, transaction=transaction)
firestore_api.get_document.assert_called_once_with(
document._document_path,
mask=mask,
transaction=expected_transaction_id,
metadata=client._rpc_metadata)

def test_get_not_found(self):
from google.cloud.firestore_v1beta1.document import DocumentSnapshot
self._get_helper(not_found=True)

# Create a minimal fake client with a dummy response.
read_time = 123
expected = DocumentSnapshot(None, None, False, read_time, None, None)
response_iterator = iter([expected])
client = mock.Mock(
_database_string='sprinklez',
spec=['_database_string', 'get_all'])
client.get_all.return_value = response_iterator

# Actually make a document and call get().
document = self._make_one('house', 'cowse', client=client)
field_paths = ['x.y', 'x.z', 't']
snapshot = document.get(field_paths=field_paths)
self.assertIsNone(snapshot.reference)
self.assertIsNone(snapshot._data)
self.assertFalse(snapshot.exists)
self.assertEqual(snapshot.read_time, expected.read_time)
self.assertIsNone(snapshot.create_time)
self.assertIsNone(snapshot.update_time)
def test_get_default(self):
self._get_helper()

# Verify the response and the mocks.
client.get_all.assert_called_once_with(
[document], field_paths=field_paths, transaction=None)
def test_get_w_string_field_path(self):
with self.assertRaises(ValueError):
self._get_helper(field_paths='foo')

def test_get_with_field_path(self):
self._get_helper(field_paths=['foo'])

def test_get_with_multiple_field_paths(self):
self._get_helper(field_paths=['foo', 'bar.baz'])

def test_get_with_transaction(self):
self._get_helper(use_transaction=True)

def _collections_helper(self, page_size=None):
from google.api_core.page_iterator import Iterator
Expand Down

0 comments on commit ade1acf

Please sign in to comment.