Skip to content

Commit

Permalink
fix: order normalization with descending query (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche authored Nov 6, 2023
1 parent 6acdb19 commit dbe8ef7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 15 deletions.
33 changes: 18 additions & 15 deletions google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@
"not-in": _operator_enum.NOT_IN,
"array_contains_any": _operator_enum.ARRAY_CONTAINS_ANY,
}
# set of operators that don't involve equlity comparisons
# will be used in query normalization
_INEQUALITY_OPERATORS = (
_operator_enum.LESS_THAN,
_operator_enum.LESS_THAN_OR_EQUAL,
_operator_enum.GREATER_THAN_OR_EQUAL,
_operator_enum.GREATER_THAN,
_operator_enum.NOT_EQUAL,
_operator_enum.NOT_IN,
)
_BAD_OP_STRING = "Operator string {!r} is invalid. Valid choices are: {}."
_BAD_OP_NAN_NULL = 'Only an equality filter ("==") can be used with None or NaN values'
_INVALID_WHERE_TRANSFORM = "Transforms cannot be used as where values."
Expand Down Expand Up @@ -858,28 +868,21 @@ def _normalize_orders(self) -> list:
if self._end_at:
if isinstance(self._end_at[0], document.DocumentSnapshot):
_has_snapshot_cursor = True

if _has_snapshot_cursor:
should_order = [
_enum_from_op_string(key)
for key in _COMPARISON_OPERATORS
if key not in (_EQ_OP, "array_contains")
]
# added orders should use direction of last order
last_direction = orders[-1].direction if orders else BaseQuery.ASCENDING
order_keys = [order.field.field_path for order in orders]
for filter_ in self._field_filters:
# FieldFilter.Operator should not compare equal to
# UnaryFilter.Operator, but it does
if isinstance(filter_.op, StructuredQuery.FieldFilter.Operator):
field = filter_.field.field_path
if filter_.op in should_order and field not in order_keys:
orders.append(self._make_order(field, "ASCENDING"))
if not orders:
orders.append(self._make_order("__name__", "ASCENDING"))
else:
order_keys = [order.field.field_path for order in orders]
if "__name__" not in order_keys:
direction = orders[-1].direction # enum?
orders.append(self._make_order("__name__", direction))
# skip equality filters and filters on fields already ordered
if filter_.op in _INEQUALITY_OPERATORS and field not in order_keys:
orders.append(self._make_order(field, last_direction))
# add __name__ if not already in orders
if "__name__" not in [order.field.field_path for order in orders]:
orders.append(self._make_order("__name__", last_direction))

return orders

Expand Down
55 changes: 55 additions & 0 deletions tests/unit/v1/test_base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,61 @@ def test_basequery__normalize_orders_w_name_orders_w_none_cursor():
assert query._normalize_orders() == expected


def test_basequery__normalize_orders_w_cursor_descending():
"""
Test case for b/306472103
"""
from google.cloud.firestore_v1.base_query import FieldFilter

collection = _make_collection("here")
snapshot = _make_snapshot(_make_docref("here", "doc_id"), {"a": 1, "b": 2})
query = (
_make_base_query(collection)
.where(filter=FieldFilter("a", "==", 1))
.where(filter=FieldFilter("b", "in", [1, 2, 3]))
.order_by("c", "DESCENDING")
)
query_w_snapshot = query.start_after(snapshot)

normalized = query._normalize_orders()
expected = [query._make_order("c", "DESCENDING")]
assert normalized == expected

normalized_w_snapshot = query_w_snapshot._normalize_orders()
expected_w_snapshot = expected + [query._make_order("__name__", "DESCENDING")]
assert normalized_w_snapshot == expected_w_snapshot


def test_basequery__normalize_orders_w_cursor_descending_w_inequality():
"""
Test case for b/306472103, with extra ineuality filter in "where" clause
"""
from google.cloud.firestore_v1.base_query import FieldFilter

collection = _make_collection("here")
snapshot = _make_snapshot(_make_docref("here", "doc_id"), {"a": 1, "b": 2})
query = (
_make_base_query(collection)
.where(filter=FieldFilter("a", "==", 1))
.where(filter=FieldFilter("b", "in", [1, 2, 3]))
.where(filter=FieldFilter("c", "not-in", [4, 5, 6]))
.order_by("d", "DESCENDING")
)
query_w_snapshot = query.start_after(snapshot)

normalized = query._normalize_orders()
expected = [query._make_order("d", "DESCENDING")]
assert normalized == expected

normalized_w_snapshot = query_w_snapshot._normalize_orders()
expected_w_snapshot = [
query._make_order("d", "DESCENDING"),
query._make_order("c", "DESCENDING"),
query._make_order("__name__", "DESCENDING"),
]
assert normalized_w_snapshot == expected_w_snapshot


def test_basequery__normalize_cursor_none():
query = _make_base_query(mock.sentinel.parent)
assert query._normalize_cursor(None, query._orders) is None
Expand Down

0 comments on commit dbe8ef7

Please sign in to comment.