checks for generators in database functions (#11564)

A couple of safety-checks to hopefully stop people doing what I just did, and create a storage
function which only works the first time it is called (and not when it is re-run due to a database
concurrency error or similar).
This commit is contained in:
Richard van der Hoff 2021-12-13 19:01:27 +00:00 committed by GitHub
parent eb39da6782
commit ff6fd52160
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 7 deletions

1
changelog.d/11564.misc Normal file
View file

@ -0,0 +1 @@
Add some safety checks that storage functions are used correctly.

View file

@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
import types
from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
@ -526,6 +528,12 @@ class DatabasePool:
the function will correctly handle being aborted and retried half way
through its execution.
Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
since they could be evaluated multiple times (which would produce an empty
result on the second or subsequent evaluation). Likewise, the closure of `func`
must not reference any generators. This method attempts to detect such usage
and will log an error.
Args:
conn
desc
@ -536,6 +544,39 @@ class DatabasePool:
**kwargs
"""
# Robustness check: ensure that none of the arguments are generators, since that
# will fail if we have to repeat the transaction.
# For now, we just log an error, and hope that it works on the first attempt.
# TODO: raise an exception.
for i, arg in enumerate(args):
if inspect.isgenerator(arg):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %i to function %s",
i,
func,
)
for name, val in kwargs.items():
if inspect.isgenerator(val):
logger.error(
"Programming error: generator passed to new_transaction as "
"argument %s to function %s",
name,
func,
)
# also check variables referenced in func's closure
if inspect.isfunction(func):
f = cast(types.FunctionType, func)
if f.__closure__:
for i, cell in enumerate(f.__closure__):
if inspect.isgenerator(cell.cell_contents):
logger.error(
"Programming error: function %s references generator %s "
"via its closure",
f,
f.__code__.co_freevars[i],
)
start = monotonic_time()
txn_id = self._TXN_ID
@ -1226,9 +1267,9 @@ class DatabasePool:
self,
table: str,
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
key_values: Collection[Collection[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
value_values: Collection[Collection[Any]],
desc: str,
) -> None:
"""
@ -1920,7 +1961,7 @@ class DatabasePool:
self,
table: str,
column: str,
iterable: Iterable[Any],
iterable: Collection[Any],
keyvalues: Dict[str, Any],
desc: str,
) -> int:
@ -1931,7 +1972,8 @@ class DatabasePool:
Args:
table: string giving the table name
column: column name to test for inclusion against `iterable`
iterable: list
iterable: list of values to match against `column`. NB cannot be a generator
as it may be evaluated multiple times.
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics

View file

@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
"""
# Add user entries to the table, updating the presence_stream_id column if the user already
# exists in the table.
presence_stream_id = self._presence_id_gen.get_current_token()
await self.db_pool.simple_upsert_many(
table="users_to_send_full_presence_to",
key_names=("user_id",),
@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
# devices at different times, each device will receive full presence once - when
# the presence stream ID in their sync token is less than the one in the table
# for their user ID.
value_values=(
(self._presence_id_gen.get_current_token(),) for _ in user_ids
),
value_values=[(presence_stream_id,) for _ in user_ids],
desc="add_users_to_send_full_presence_to",
)