Skip to content

Commit

Permalink
Add 'Transaction.get' / 'Transaction.get_all'. (#8628)
Browse files Browse the repository at this point in the history
Closes #6557.
  • Loading branch information
HemangChothani authored and tseaver committed Jul 15, 2019
1 parent 8b69072 commit 49e7ce3
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
33 changes: 33 additions & 0 deletions firestore/google/cloud/firestore_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from google.api_core import exceptions
from google.cloud.firestore_v1 import batch
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.document import DocumentReference
from google.cloud.firestore_v1.query import Query


MAX_ATTEMPTS = 5
Expand Down Expand Up @@ -200,6 +202,37 @@ def _commit(self):
self._clean_up()
return list(commit_response.write_results)

def get_all(self, references):
"""Retrieves multiple documents from Firestore.
Args:
references (List[.DocumentReference, ...]): Iterable of document
references to be retrieved.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
return self._client.get_all(references, transaction=self._id)

def get(self, ref_or_query):
"""
Retrieve a document or a query result from the database.
Args:
ref_or_query The document references or query object to return.
Yields:
.DocumentSnapshot: The next document snapshot that fulfills the
query, or :data:`None` if the document does not exist.
"""
if isinstance(ref_or_query, DocumentReference):
return self._client.get_all([ref_or_query], transaction=self._id)
elif isinstance(ref_or_query, Query):
return ref_or_query.stream(transaction=self._id)
else:
raise ValueError(
'Value for argument "ref_or_query" must be a DocumentReference or a Query.'
)


class _Transactional(object):
"""Provide a callable object to use as a transactional decorater.
Expand Down
37 changes: 36 additions & 1 deletion firestore/tests/unit/v1/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import unittest

import mock


Expand Down Expand Up @@ -329,6 +328,42 @@ def test__commit_failure(self):
metadata=client._rpc_metadata,
)

def test_get_all(self):
client = mock.Mock(spec=["get_all"])
transaction = self._make_one(client)
ref1, ref2 = mock.Mock(), mock.Mock()
result = transaction.get_all([ref1, ref2])
client.get_all.assert_called_once_with([ref1, ref2], transaction=transaction.id)
self.assertIs(result, client.get_all.return_value)

def test_get_document_ref(self):
from google.cloud.firestore_v1.document import DocumentReference

client = mock.Mock(spec=["get_all"])
transaction = self._make_one(client)
ref = DocumentReference("documents", "doc-id")
result = transaction.get(ref)
client.get_all.assert_called_once_with([ref], transaction=transaction.id)
self.assertIs(result, client.get_all.return_value)

def test_get_w_query(self):
from google.cloud.firestore_v1.query import Query

client = mock.Mock(spec=[])
transaction = self._make_one(client)
query = Query(parent=mock.Mock(spec=[]))
query.stream = mock.MagicMock()
result = transaction.get(query)
query.stream.assert_called_once_with(transaction=transaction.id)
self.assertIs(result, query.stream.return_value)

def test_get_failure(self):
client = _make_client()
transaction = self._make_one(client)
ref_or_query = object()
with self.assertRaises(ValueError):
transaction.get(ref_or_query)


class Test_Transactional(unittest.TestCase):
@staticmethod
Expand Down

0 comments on commit 49e7ce3

Please sign in to comment.