checks for generators in database functions (#11564)

A couple of safety-checks to hopefully stop people doing what I just did, and create a storage
function which only works the first time it is called (and not when it is re-run due to a database
concurrency error or similar).
This commit is contained in:
Richard van der Hoff 2021-12-13 19:01:27 +00:00 committed by GitHub
parent eb39da6782
commit ff6fd52160
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 7 deletions

1
changelog.d/11564.misc Normal file
View file

@ -0,0 +1 @@
Add some safety checks that storage functions are used correctly.

View file

@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import time import time
import types
from collections import defaultdict from collections import defaultdict
from sys import intern from sys import intern
from time import monotonic as monotonic_time from time import monotonic as monotonic_time
@ -526,6 +528,12 @@ class DatabasePool:
the function will correctly handle being aborted and retried half way the function will correctly handle being aborted and retried half way
through its execution. through its execution.
Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
since they could be evaluated multiple times (which would produce an empty
result on the second or subsequent evaluation). Likewise, the closure of `func`
must not reference any generators. This method attempts to detect such usage
and will log an error.
Args: Args:
conn conn
desc desc
@ -536,6 +544,39 @@ class DatabasePool:
**kwargs **kwargs
""" """
# Robustness check: ensure that none of the arguments are generators, since that
# will fail if we have to repeat the transaction.
# For now, we just log an error, and hope that it works on the first attempt.
# TODO: raise an exception.
for i, arg in enumerate(args):
if inspect.isgenerator(arg):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %i to function %s",
i,
func,
)
for name, val in kwargs.items():
if inspect.isgenerator(val):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %s to function %s",
name,
func,
)
# also check variables referenced in func's closure
if inspect.isfunction(func):
f = cast(types.FunctionType, func)
if f.__closure__:
for i, cell in enumerate(f.__closure__):
if inspect.isgenerator(cell.cell_contents):
logger.error(
"Programming error: function %s references generator %s "
"via its closure",
f,
f.__code__.co_freevars[i],
)
start = monotonic_time() start = monotonic_time()
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -1226,9 +1267,9 @@ class DatabasePool:
self, self,
table: str, table: str,
key_names: Collection[str], key_names: Collection[str],
key_values: Collection[Iterable[Any]], key_values: Collection[Collection[Any]],
value_names: Collection[str], value_names: Collection[str],
value_values: Iterable[Iterable[Any]], value_values: Collection[Collection[Any]],
desc: str, desc: str,
) -> None: ) -> None:
""" """
@ -1920,7 +1961,7 @@ class DatabasePool:
self, self,
table: str, table: str,
column: str, column: str,
iterable: Iterable[Any], iterable: Collection[Any],
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
desc: str, desc: str,
) -> int: ) -> int:
@ -1931,7 +1972,8 @@ class DatabasePool:
Args: Args:
table: string giving the table name table: string giving the table name
column: column name to test for inclusion against `iterable` column: column name to test for inclusion against `iterable`
iterable: list iterable: list of values to match against `column`. NB cannot be a generator
as it may be evaluated multiple times.
keyvalues: dict of column names and values to select the rows with keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics

View file

@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
""" """
# Add user entries to the table, updating the presence_stream_id column if the user already # Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table. # exists in the table.
presence_stream_id = self._presence_id_gen.get_current_token()
await self.db_pool.simple_upsert_many( await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to", table="users_to_send_full_presence_to",
key_names=("user_id",), key_names=("user_id",),
@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# devices at different times, each device will receive full presence once - when # devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table # the presence stream ID in their sync token is less than the one in the table
# for their user ID. # for their user ID.
value_values=( value_values=[(presence_stream_id,) for _ in user_ids],
(self._presence_id_gen.get_current_token(),) for _ in user_ids
),
desc="add_users_to_send_full_presence_to", desc="add_users_to_send_full_presence_to",
) )