mirror of
https://github.com/element-hq/synapse.git
synced 2024-12-14 11:57:44 +00:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/createroom_content
This commit is contained in:
commit
d8a6c734fa
73 changed files with 1778 additions and 822 deletions
|
@ -5,7 +5,8 @@ To use it, first install prometheus by following the instructions at
|
|||
|
||||
http://prometheus.io/
|
||||
|
||||
Then add a new job to the main prometheus.conf file:
|
||||
### for Prometheus v1
|
||||
Add a new job to the main prometheus.conf file:
|
||||
|
||||
job: {
|
||||
name: "synapse"
|
||||
|
@ -15,6 +16,22 @@ Then add a new job to the main prometheus.conf file:
|
|||
}
|
||||
}
|
||||
|
||||
### for Prometheus v2
|
||||
Add a new job to the main prometheus.yml file:
|
||||
|
||||
- job_name: "synapse"
|
||||
metrics_path: "/_synapse/metrics"
|
||||
# when endpoint uses https:
|
||||
scheme: "https"
|
||||
|
||||
static_configs:
|
||||
- targets: ['SERVER.LOCATION:PORT']
|
||||
|
||||
To use `synapse.rules` add
|
||||
|
||||
rule_files:
|
||||
- "/PATH/TO/synapse-v2.rules"
|
||||
|
||||
Metrics are disabled by default when running synapse; they must be enabled
|
||||
with the 'enable-metrics' option, either in the synapse config file or as a
|
||||
command-line option.
|
||||
|
|
60
contrib/prometheus/synapse-v2.rules
Normal file
60
contrib/prometheus/synapse-v2.rules
Normal file
|
@ -0,0 +1,60 @@
|
|||
groups:
|
||||
- name: synapse
|
||||
rules:
|
||||
- record: "synapse_federation_transaction_queue_pendingEdus:total"
|
||||
expr: "sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)"
|
||||
- record: "synapse_federation_transaction_queue_pendingPdus:total"
|
||||
expr: "sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)"
|
||||
- record: 'synapse_http_server_requests:method'
|
||||
labels:
|
||||
servlet: ""
|
||||
expr: "sum(synapse_http_server_requests) by (method)"
|
||||
- record: 'synapse_http_server_requests:servlet'
|
||||
labels:
|
||||
method: ""
|
||||
expr: 'sum(synapse_http_server_requests) by (servlet)'
|
||||
|
||||
- record: 'synapse_http_server_requests:total'
|
||||
labels:
|
||||
servlet: ""
|
||||
expr: 'sum(synapse_http_server_requests:by_method) by (servlet)'
|
||||
|
||||
- record: 'synapse_cache:hit_ratio_5m'
|
||||
expr: 'rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])'
|
||||
- record: 'synapse_cache:hit_ratio_30s'
|
||||
expr: 'rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])'
|
||||
|
||||
- record: 'synapse_federation_client_sent'
|
||||
labels:
|
||||
type: "EDU"
|
||||
expr: 'synapse_federation_client_sent_edus + 0'
|
||||
- record: 'synapse_federation_client_sent'
|
||||
labels:
|
||||
type: "PDU"
|
||||
expr: 'synapse_federation_client_sent_pdu_destinations:count + 0'
|
||||
- record: 'synapse_federation_client_sent'
|
||||
labels:
|
||||
type: "Query"
|
||||
expr: 'sum(synapse_federation_client_sent_queries) by (job)'
|
||||
|
||||
- record: 'synapse_federation_server_received'
|
||||
labels:
|
||||
type: "EDU"
|
||||
expr: 'synapse_federation_server_received_edus + 0'
|
||||
- record: 'synapse_federation_server_received'
|
||||
labels:
|
||||
type: "PDU"
|
||||
expr: 'synapse_federation_server_received_pdus + 0'
|
||||
- record: 'synapse_federation_server_received'
|
||||
labels:
|
||||
type: "Query"
|
||||
expr: 'sum(synapse_federation_server_received_queries) by (job)'
|
||||
|
||||
- record: 'synapse_federation_transaction_queue_pending'
|
||||
labels:
|
||||
type: "EDU"
|
||||
expr: 'synapse_federation_transaction_queue_pending_edus + 0'
|
||||
- record: 'synapse_federation_transaction_queue_pending'
|
||||
labels:
|
||||
type: "PDU"
|
||||
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
|
|
@ -298,10 +298,6 @@ It can be used like this:
|
|||
# this will now be logged against the request context
|
||||
logger.debug("Request handling complete")
|
||||
|
||||
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
|
||||
but the fact that it does ``preserve_context_over_deferred`` on its results
|
||||
means that its use is fraught with difficulty.
|
||||
|
||||
Passing synapse deferreds into third-party functions
|
||||
----------------------------------------------------
|
||||
|
||||
|
|
157
docs/workers.rst
157
docs/workers.rst
|
@ -1,11 +1,15 @@
|
|||
Scaling synapse via workers
|
||||
---------------------------
|
||||
===========================
|
||||
|
||||
Synapse has experimental support for splitting out functionality into
|
||||
multiple separate python processes, helping greatly with scalability. These
|
||||
processes are called 'workers', and are (eventually) intended to scale
|
||||
horizontally independently.
|
||||
|
||||
All of the below is highly experimental and subject to change as Synapse evolves,
|
||||
but documenting it here to help folks needing highly scalable Synapses similar
|
||||
to the one running matrix.org!
|
||||
|
||||
All processes continue to share the same database instance, and as such, workers
|
||||
only work with postgres based synapse deployments (sharing a single sqlite
|
||||
across multiple processes is a recipe for disaster, plus you should be using
|
||||
|
@ -16,6 +20,16 @@ TCP protocol called 'replication' - analogous to MySQL or Postgres style
|
|||
database replication; feeding a stream of relevant data to the workers so they
|
||||
can be kept in sync with the main synapse process and database state.
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
To make effective use of the workers, you will need to configure an HTTP
|
||||
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
|
||||
the correct worker, or to the main synapse instance. Note that this includes
|
||||
requests made to the federation port. The caveats regarding running a
|
||||
reverse-proxy on the federation port still apply (see
|
||||
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
|
||||
|
||||
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
||||
|
||||
listeners:
|
||||
|
@ -27,26 +41,19 @@ Under **no circumstances** should this replication API listener be exposed to th
|
|||
public internet; it currently implements no authentication whatsoever and is
|
||||
unencrypted.
|
||||
|
||||
You then create a set of configs for the various worker processes. These should be
|
||||
worker configuration files should be stored in a dedicated subdirectory, to allow
|
||||
synctl to manipulate them.
|
||||
|
||||
The current available worker applications are:
|
||||
* synapse.app.pusher - handles sending push notifications to sygnal and email
|
||||
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
|
||||
* synapse.app.appservice - handles output traffic to Application Services
|
||||
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
|
||||
* synapse.app.media_repository - handles the media repository.
|
||||
* synapse.app.client_reader - handles client API endpoints like /publicRooms
|
||||
You then create a set of configs for the various worker processes. These
|
||||
should be worker configuration files, and should be stored in a dedicated
|
||||
subdirectory, to allow synctl to manipulate them.
|
||||
|
||||
Each worker configuration file inherits the configuration of the main homeserver
|
||||
configuration file. You can then override configuration specific to that worker,
|
||||
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
|
||||
You should minimise the number of overrides though to maintain a usable config.
|
||||
|
||||
You must specify the type of worker application (worker_app) and the replication
|
||||
endpoint that it's talking to on the main synapse process (worker_replication_host
|
||||
and worker_replication_port).
|
||||
You must specify the type of worker application (``worker_app``). The currently
|
||||
available worker applications are listed below. You must also specify the
|
||||
replication endpoint that it's talking to on the main synapse process
|
||||
(``worker_replication_host`` and ``worker_replication_port``).
|
||||
|
||||
For instance::
|
||||
|
||||
|
@ -68,11 +75,11 @@ For instance::
|
|||
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
|
||||
|
||||
...is a full configuration for a synchrotron worker instance, which will expose a
|
||||
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
|
||||
plain HTTP ``/sync`` endpoint on port 8083 separately from the ``/sync`` endpoint provided
|
||||
by the main synapse.
|
||||
|
||||
Obviously you should configure your loadbalancer to route the /sync endpoint to
|
||||
the synchrotron instance(s) in this instance.
|
||||
Obviously you should configure your reverse-proxy to route the relevant
|
||||
endpoints to the worker (``localhost:8083`` in the above example).
|
||||
|
||||
Finally, to actually run your worker-based synapse, you must pass synctl the -a
|
||||
commandline option to tell it to operate on all the worker configurations found
|
||||
|
@ -89,6 +96,114 @@ To manipulate a specific worker, you pass the -w option to synctl::
|
|||
|
||||
synctl -w $CONFIG/workers/synchrotron.yaml restart
|
||||
|
||||
All of the above is highly experimental and subject to change as Synapse evolves,
|
||||
but documenting it here to help folks needing highly scalable Synapses similar
|
||||
to the one running matrix.org!
|
||||
|
||||
Available worker applications
|
||||
-----------------------------
|
||||
|
||||
``synapse.app.pusher``
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles sending push notifications to sygnal and email. Doesn't handle any
|
||||
REST endpoints itself, but you should set ``start_pushers: False`` in the
|
||||
shared configuration file to stop the main synapse sending these notifications.
|
||||
|
||||
Note this worker cannot be load-balanced: only one instance should be active.
|
||||
|
||||
``synapse.app.synchrotron``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The synchrotron handles ``sync`` requests from clients. In particular, it can
|
||||
handle REST endpoints matching the following regular expressions::
|
||||
|
||||
^/_matrix/client/(v2_alpha|r0)/sync$
|
||||
^/_matrix/client/(api/v1|v2_alpha|r0)/events$
|
||||
^/_matrix/client/(api/v1|r0)/initialSync$
|
||||
^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
|
||||
|
||||
The above endpoints should all be routed to the synchrotron worker by the
|
||||
reverse-proxy configuration.
|
||||
|
||||
It is possible to run multiple instances of the synchrotron to scale
|
||||
horizontally. In this case the reverse-proxy should be configured to
|
||||
load-balance across the instances, though it will be more efficient if all
|
||||
requests from a particular user are routed to a single instance. Extracting
|
||||
a userid from the access token is currently left as an exercise for the reader.
|
||||
|
||||
``synapse.app.appservice``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles sending output traffic to Application Services. Doesn't handle any
|
||||
REST endpoints itself, but you should set ``notify_appservices: False`` in the
|
||||
shared configuration file to stop the main synapse sending these notifications.
|
||||
|
||||
Note this worker cannot be load-balanced: only one instance should be active.
|
||||
|
||||
``synapse.app.federation_reader``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles a subset of federation endpoints. In particular, it can handle REST
|
||||
endpoints matching the following regular expressions::
|
||||
|
||||
^/_matrix/federation/v1/event/
|
||||
^/_matrix/federation/v1/state/
|
||||
^/_matrix/federation/v1/state_ids/
|
||||
^/_matrix/federation/v1/backfill/
|
||||
^/_matrix/federation/v1/get_missing_events/
|
||||
^/_matrix/federation/v1/publicRooms
|
||||
|
||||
The above endpoints should all be routed to the federation_reader worker by the
|
||||
reverse-proxy configuration.
|
||||
|
||||
``synapse.app.federation_sender``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles sending federation traffic to other servers. Doesn't handle any
|
||||
REST endpoints itself, but you should set ``send_federation: False`` in the
|
||||
shared configuration file to stop the main synapse sending this traffic.
|
||||
|
||||
Note this worker cannot be load-balanced: only one instance should be active.
|
||||
|
||||
``synapse.app.media_repository``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles the media repository. It can handle all endpoints starting with::
|
||||
|
||||
/_matrix/media/
|
||||
|
||||
You should also set ``enable_media_repo: False`` in the shared configuration
|
||||
file to stop the main synapse running background jobs related to managing the
|
||||
media repository.
|
||||
|
||||
Note this worker cannot be load-balanced: only one instance should be active.
|
||||
|
||||
``synapse.app.client_reader``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles client API endpoints. It can handle REST endpoints matching the
|
||||
following regular expressions::
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
|
||||
|
||||
``synapse.app.user_dir``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles searches in the user directory. It can handle REST endpoints matching
|
||||
the following regular expressions::
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
|
||||
|
||||
``synapse.app.frontend_proxy``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Proxies some frequently-requested client endpoints to add caching and remove
|
||||
load from the main synapse. It can handle REST endpoints matching the following
|
||||
regular expressions::
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
|
||||
|
||||
It will proxy any requests it cannot handle to the main synapse instance. It
|
||||
must therefore be configured with the location of the main instance, via
|
||||
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
|
||||
file. For example::
|
||||
|
||||
worker_main_http_uri: http://127.0.0.1:8008
|
||||
|
|
|
@ -123,15 +123,25 @@ def lookup(destination, path):
|
|||
except:
|
||||
return "https://%s:%d%s" % (destination, 8448, path)
|
||||
|
||||
def get_json(origin_name, origin_key, destination, path):
|
||||
request_json = {
|
||||
"method": "GET",
|
||||
|
||||
def request_json(method, origin_name, origin_key, destination, path, content):
|
||||
if method is None:
|
||||
if content is None:
|
||||
method = "GET"
|
||||
else:
|
||||
method = "POST"
|
||||
|
||||
json_to_sign = {
|
||||
"method": method,
|
||||
"uri": path,
|
||||
"origin": origin_name,
|
||||
"destination": destination,
|
||||
}
|
||||
|
||||
signed_json = sign_json(request_json, origin_key, origin_name)
|
||||
if content is not None:
|
||||
json_to_sign["content"] = json.loads(content)
|
||||
|
||||
signed_json = sign_json(json_to_sign, origin_key, origin_name)
|
||||
|
||||
authorization_headers = []
|
||||
|
||||
|
@ -145,10 +155,12 @@ def get_json(origin_name, origin_key, destination, path):
|
|||
dest = lookup(destination, path)
|
||||
print ("Requesting %s" % dest, file=sys.stderr)
|
||||
|
||||
result = requests.get(
|
||||
dest,
|
||||
result = requests.request(
|
||||
method=method,
|
||||
url=dest,
|
||||
headers={"Authorization": authorization_headers[0]},
|
||||
verify=False,
|
||||
data=content,
|
||||
)
|
||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||
return result.json()
|
||||
|
@ -186,6 +198,17 @@ def main():
|
|||
"connect appropriately.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-X", "--method",
|
||||
help="HTTP method to use for the request. Defaults to GET if --data is"
|
||||
"unspecified, POST if it is."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--body",
|
||||
help="Data to send as the body of the HTTP request"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"path",
|
||||
help="request path. We will add '/_matrix/federation/v1/' to this."
|
||||
|
@ -199,8 +222,11 @@ def main():
|
|||
with open(args.signing_key_path) as f:
|
||||
key = read_signing_keys(f)[0]
|
||||
|
||||
result = get_json(
|
||||
args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path
|
||||
result = request_json(
|
||||
args.method,
|
||||
args.server_name, key, args.destination,
|
||||
"/_matrix/federation/v1/" + args.path,
|
||||
content=args.body,
|
||||
)
|
||||
|
||||
json.dump(result, sys.stdout)
|
||||
|
|
45
scripts/sync_room_to_group.pl
Executable file
45
scripts/sync_room_to_group.pl
Executable file
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env perl
|
||||
|
||||
use strict;
|
||||
use warnings;
|
||||
|
||||
use JSON::XS;
|
||||
use LWP::UserAgent;
|
||||
use URI::Escape;
|
||||
|
||||
if (@ARGV < 4) {
|
||||
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
|
||||
}
|
||||
|
||||
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
|
||||
my $ua = LWP::UserAgent->new();
|
||||
$ua->timeout(10);
|
||||
|
||||
if ($room_id =~ /^#/) {
|
||||
$room_id = uri_escape($room_id);
|
||||
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
|
||||
}
|
||||
|
||||
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
|
||||
my $group_users = [
|
||||
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||
];
|
||||
|
||||
die "refusing to sync from empty room" unless (@$room_users);
|
||||
die "refusing to sync to empty group" unless (@$group_users);
|
||||
|
||||
my $diff = {};
|
||||
foreach my $user (@$room_users) { $diff->{$user}++ }
|
||||
foreach my $user (@$group_users) { $diff->{$user}-- }
|
||||
|
||||
foreach my $user (keys %$diff) {
|
||||
if ($diff->{$user} == 1) {
|
||||
warn "inviting $user";
|
||||
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||
}
|
||||
elsif ($diff->{$user} == -1) {
|
||||
warn "removing $user";
|
||||
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||
}
|
||||
}
|
|
@ -270,7 +270,11 @@ class Auth(object):
|
|||
rights (str): The operation being performed; the access token must
|
||||
allow this.
|
||||
Returns:
|
||||
dict : dict that includes the user and the ID of their access token.
|
||||
Deferred[dict]: dict that includes:
|
||||
`user` (UserID)
|
||||
`is_guest` (bool)
|
||||
`token_id` (int|None): access token id. May be None if guest
|
||||
`device_id` (str|None): device corresponding to access token
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
|
|
|
@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
|
|||
pass
|
||||
|
||||
|
||||
class InteractiveAuthIncompleteError(Exception):
|
||||
"""An error raised when UI auth is not yet complete
|
||||
|
||||
(This indicates we should return a 401 with 'result' as the body)
|
||||
|
||||
Attributes:
|
||||
result (dict): the server response to the request, which should be
|
||||
passed back to the client
|
||||
"""
|
||||
def __init__(self, result):
|
||||
super(InteractiveAuthIncompleteError, self).__init__(
|
||||
"Interactive auth not yet complete",
|
||||
)
|
||||
self.result = result
|
||||
|
||||
|
||||
class UnrecognizedRequestError(SynapseError):
|
||||
"""An error indicating we don't understand the request you're trying to make"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource
|
|||
from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage import are_all_users_on_domain
|
||||
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
|
||||
|
@ -195,14 +194,19 @@ class SynapseHomeServer(HomeServer):
|
|||
})
|
||||
|
||||
if name in ["media", "federation", "client"]:
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
if self.get_config().enable_media_repo:
|
||||
media_repo = self.get_media_repository_resource()
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
elif name == "media":
|
||||
raise ConfigError(
|
||||
"'media' resource conflicts with enable_media_repo=False",
|
||||
)
|
||||
|
||||
if name in ["keys", "federation"]:
|
||||
resources.update({
|
||||
|
|
|
@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
|||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.media_repository import MediaRepositoryStore
|
||||
|
@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer):
|
|||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "media":
|
||||
media_repo = MediaRepositoryResource(self)
|
||||
media_repo = self.get_media_repository_resource()
|
||||
resources.update({
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
|
@ -151,6 +150,13 @@ def start(config_options):
|
|||
|
||||
assert config.worker_app == "synapse.app.media_repository"
|
||||
|
||||
if config.enable_media_repo:
|
||||
_base.quit_with_error(
|
||||
"enable_media_repo must be disabled in the main synapse process\n"
|
||||
"before the media repo can be run in a separate worker.\n"
|
||||
"Please add ``enable_media_repo: false`` to the main config\n"
|
||||
)
|
||||
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
|
|
@ -340,11 +340,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
|||
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
# NB this is a SynchrotronPresence, not a normal PresenceHandler
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self.presence_handler.sync_callback = self.send_user_sync
|
||||
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# limitations under the License.
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.types import GroupID, get_domain_from_id
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -81,12 +82,13 @@ class ApplicationService(object):
|
|||
# values.
|
||||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||
|
||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
||||
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
|
||||
sender=None, id=None, protocols=None, rate_limited=True):
|
||||
self.token = token
|
||||
self.url = url
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.server_name = hostname
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
|
||||
|
@ -125,6 +127,24 @@ class ApplicationService(object):
|
|||
raise ValueError(
|
||||
"Expected bool for 'exclusive' in ns '%s'" % ns
|
||||
)
|
||||
group_id = regex_obj.get("group_id")
|
||||
if group_id:
|
||||
if not isinstance(group_id, str):
|
||||
raise ValueError(
|
||||
"Expected string for 'group_id' in ns '%s'" % ns
|
||||
)
|
||||
try:
|
||||
GroupID.from_string(group_id)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Expected valid group ID for 'group_id' in ns '%s'" % ns
|
||||
)
|
||||
|
||||
if get_domain_from_id(group_id) != self.server_name:
|
||||
raise ValueError(
|
||||
"Expected 'group_id' to be this host in ns '%s'" % ns
|
||||
)
|
||||
|
||||
regex = regex_obj.get("regex")
|
||||
if isinstance(regex, basestring):
|
||||
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
|
||||
|
@ -251,6 +271,21 @@ class ApplicationService(object):
|
|||
if regex_obj["exclusive"]
|
||||
]
|
||||
|
||||
def get_groups_for_user(self, user_id):
|
||||
"""Get the groups that this user is associated with by this AS
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user.
|
||||
|
||||
Returns:
|
||||
iterable[str]: an iterable that yields group_id strings.
|
||||
"""
|
||||
return (
|
||||
regex_obj["group_id"]
|
||||
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
|
||||
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
|
||||
)
|
||||
|
||||
def is_rate_limited(self):
|
||||
return self.rate_limited
|
||||
|
||||
|
|
|
@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||
)
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
hostname=hostname,
|
||||
url=as_info["url"],
|
||||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
|
|
|
@ -36,6 +36,7 @@ from .workers import WorkerConfig
|
|||
from .push import PushConfig
|
||||
from .spam_checker import SpamCheckerConfig
|
||||
from .groups import GroupsConfig
|
||||
from .user_directory import UserDirectoryConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
|
@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
|||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
||||
SpamCheckerConfig, GroupsConfig,):
|
||||
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -18,28 +19,43 @@ from ._base import Config
|
|||
|
||||
class PushConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.push_redact_content = False
|
||||
push_config = config.get("push", {})
|
||||
self.push_include_content = push_config.get("include_content", True)
|
||||
|
||||
# There was a a 'redact_content' setting but mistakenly read from the
|
||||
# 'email'section'. Check for the flag in the 'push' section, and log,
|
||||
# but do not honour it to avoid nasty surprises when people upgrade.
|
||||
if push_config.get("redact_content") is not None:
|
||||
print(
|
||||
"The push.redact_content content option has never worked. "
|
||||
"Please set push.include_content if you want this behaviour"
|
||||
)
|
||||
|
||||
# Now check for the one in the 'email' section and honour it,
|
||||
# with a warning.
|
||||
push_config = config.get("email", {})
|
||||
self.push_redact_content = push_config.get("redact_content", False)
|
||||
redact_content = push_config.get("redact_content")
|
||||
if redact_content is not None:
|
||||
print(
|
||||
"The 'email.redact_content' option is deprecated: "
|
||||
"please set push.include_content instead"
|
||||
)
|
||||
self.push_include_content = not redact_content
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Control how push messages are sent to google/apple to notifications.
|
||||
# Normally every message said in a room with one or more people using
|
||||
# mobile devices will be posted to a push server hosted by matrix.org
|
||||
# which is registered with google and apple in order to allow push
|
||||
# notifications to be sent to these mobile devices.
|
||||
#
|
||||
# Setting redact_content to true will make the push messages contain no
|
||||
# message content which will provide increased privacy. This is a
|
||||
# temporary solution pending improvements to Android and iPhone apps
|
||||
# to get content from the app rather than the notification.
|
||||
#
|
||||
# Clients requesting push notifications can either have the body of
|
||||
# the message sent in the notification poke along with other details
|
||||
# like the sender, or just the event ID and room ID (`event_id_only`).
|
||||
# If clients choose the former, this option controls whether the
|
||||
# notification request includes the content of the event (other details
|
||||
# like the sender are still included). For `event_id_only` push, it
|
||||
# has no effect.
|
||||
|
||||
# For modern android devices the notification content will still appear
|
||||
# because it is loaded by the app. iPhone, however will send a
|
||||
# notification saying only that a message arrived and who it came from.
|
||||
#
|
||||
#push:
|
||||
# redact_content: false
|
||||
# include_content: true
|
||||
"""
|
||||
|
|
|
@ -41,6 +41,12 @@ class ServerConfig(Config):
|
|||
# false only if we are updating the user directory in a worker
|
||||
self.update_user_directory = config.get("update_user_directory", True)
|
||||
|
||||
# whether to enable the media repository endpoints. This should be set
|
||||
# to false if the media repository is running as a separate endpoint;
|
||||
# doing so ensures that we will not run cache cleanup jobs on the
|
||||
# master, potentially causing inconsistency.
|
||||
self.enable_media_repo = config.get("enable_media_repo", True)
|
||||
|
||||
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
|
||||
|
||||
# Whether we should block invites sent to users on this server
|
||||
|
|
44
synapse/config/user_directory.py
Normal file
44
synapse/config/user_directory.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class UserDirectoryConfig(Config):
|
||||
"""User Directory Configuration
|
||||
Configuration for the behaviour of the /user_directory API
|
||||
"""
|
||||
|
||||
def read_config(self, config):
|
||||
self.user_directory_search_all_users = False
|
||||
user_directory_config = config.get("user_directory", None)
|
||||
if user_directory_config:
|
||||
self.user_directory_search_all_users = (
|
||||
user_directory_config.get("search_all_users", False)
|
||||
)
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# User Directory configuration
|
||||
#
|
||||
# 'search_all_users' defines whether to search all users visible to your HS
|
||||
# when searching the user directory, rather than limiting to users visible
|
||||
# in public rooms. Defaults to false. If you set it True, you'll have to run
|
||||
# UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||
# on your database to tell it to rebuild the user_directory search indexes.
|
||||
#
|
||||
#user_directory:
|
||||
# search_all_users: false
|
||||
"""
|
|
@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
|||
"""Check whether the hash for this PDU matches the contents"""
|
||||
name, expected_hash = compute_content_hash(event, hash_algorithm)
|
||||
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
|
||||
if name not in event.hashes:
|
||||
|
||||
# some malformed events lack a 'hashes'. Protect against it being missing
|
||||
# or a weird type by basically treating it the same as an unhashed event.
|
||||
hashes = event.get("hashes")
|
||||
if not isinstance(hashes, dict):
|
||||
raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
|
||||
|
||||
if name not in hashes:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Algorithm %s not in hashes %s" % (
|
||||
name, list(event.hashes),
|
||||
name, list(hashes),
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
message_hash_base64 = event.hashes[name]
|
||||
message_hash_base64 = hashes[name]
|
||||
try:
|
||||
message_hash_bytes = decode_base64(message_hash_base64)
|
||||
except Exception:
|
||||
|
|
|
@ -25,7 +25,7 @@ from synapse.api.errors import (
|
|||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.events import FrozenEvent, builder
|
||||
import synapse.metrics
|
||||
|
||||
|
@ -420,7 +420,7 @@ class FederationClient(FederationBase):
|
|||
for e_id in batch
|
||||
]
|
||||
|
||||
res = yield preserve_context_over_deferred(
|
||||
res = yield make_deferred_yieldable(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
for success, result in res:
|
||||
|
|
|
@ -20,7 +20,7 @@ from .persistence import TransactionActions
|
|||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util import logcontext
|
||||
from synapse.util import logcontext, PreserveLoggingContext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
from synapse.util.metrics import measure_func
|
||||
|
@ -146,7 +146,6 @@ class TransactionQueue(object):
|
|||
else:
|
||||
return not destination.startswith("localhost")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_new_events(self, current_id):
|
||||
"""This gets called when we have some new events we might want to
|
||||
send out to other servers.
|
||||
|
@ -156,6 +155,13 @@ class TransactionQueue(object):
|
|||
if self._is_processing:
|
||||
return
|
||||
|
||||
# fire off a processing loop in the background. It's likely it will
|
||||
# outlast the current request, so run it in the sentinel logcontext.
|
||||
with PreserveLoggingContext():
|
||||
self._process_event_queue_loop()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_event_queue_loop(self):
|
||||
try:
|
||||
self._is_processing = True
|
||||
while True:
|
||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -159,7 +159,7 @@ class ApplicationServicesHandler(object):
|
|||
def query_3pe(self, kind, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield preserve_context_over_deferred(defer.DeferredList([
|
||||
results = yield make_deferred_yieldable(defer.DeferredList([
|
||||
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
|
||||
for service in services
|
||||
], consumeErrors=True))
|
||||
|
|
|
@ -17,7 +17,10 @@ from twisted.internet import defer
|
|||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
@ -46,7 +49,6 @@ class AuthHandler(BaseHandler):
|
|||
"""
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
self.checkers = {
|
||||
LoginType.PASSWORD: self._check_password_auth,
|
||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||
LoginType.MSISDN: self._check_msisdn,
|
||||
|
@ -75,15 +77,76 @@ class AuthHandler(BaseHandler):
|
|||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._password_enabled = hs.config.password_enabled
|
||||
|
||||
login_types = set()
|
||||
# we keep this as a list despite the O(N^2) implication so that we can
|
||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||
# type in the list. (NB that the spec doesn't require us to do so and
|
||||
# clients which favour types that they don't understand over those that
|
||||
# they do are technically broken)
|
||||
login_types = []
|
||||
if self._password_enabled:
|
||||
login_types.add(LoginType.PASSWORD)
|
||||
login_types.append(LoginType.PASSWORD)
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "get_supported_login_types"):
|
||||
login_types.update(
|
||||
provider.get_supported_login_types().keys()
|
||||
)
|
||||
self._supported_login_types = frozenset(login_types)
|
||||
for t in provider.get_supported_login_types().keys():
|
||||
if t not in login_types:
|
||||
login_types.append(t)
|
||||
self._supported_login_types = login_types
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
||||
"""
|
||||
Checks that the user is who they claim to be, via a UI auth.
|
||||
|
||||
This is used for things like device deletion and password reset where
|
||||
the user already has a valid access token, but we want to double-check
|
||||
that it isn't stolen by re-authenticating them.
|
||||
|
||||
Args:
|
||||
requester (Requester): The user, as given by the access token
|
||||
|
||||
request_body (dict): The body of the request sent by the client
|
||||
|
||||
clientip (str): The IP address of the client.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[dict]: the parameters for this request (which may
|
||||
have been given only in a previous call).
|
||||
|
||||
Raises:
|
||||
InteractiveAuthIncompleteError if the client has not yet completed
|
||||
any of the permitted login flows
|
||||
|
||||
AuthError if the client has completed a login flow, and it gives
|
||||
a different user to `requester`
|
||||
"""
|
||||
|
||||
# build a list of supported flows
|
||||
flows = [
|
||||
[login_type] for login_type in self._supported_login_types
|
||||
]
|
||||
|
||||
result, params, _ = yield self.check_auth(
|
||||
flows, request_body, clientip,
|
||||
)
|
||||
|
||||
# find the completed login type
|
||||
for login_type in self._supported_login_types:
|
||||
if login_type not in result:
|
||||
continue
|
||||
|
||||
user_id = result[login_type]
|
||||
break
|
||||
else:
|
||||
# this can't happen
|
||||
raise Exception(
|
||||
"check_auth returned True but no successful login type",
|
||||
)
|
||||
|
||||
# check that the UI auth matched the access token
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Invalid auth")
|
||||
|
||||
defer.returnValue(params)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
|
@ -95,26 +158,36 @@ class AuthHandler(BaseHandler):
|
|||
session with a map, which maps each auth-type (str) to the relevant
|
||||
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
|
||||
If no auth flows have been completed successfully, raises an
|
||||
InteractiveAuthIncompleteError. To handle this, you can use
|
||||
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
|
||||
decorator.
|
||||
|
||||
Args:
|
||||
flows (list): A list of login flows. Each flow is an ordered list of
|
||||
strings representing auth-types. At least one full
|
||||
flow must be completed in order for auth to be successful.
|
||||
|
||||
clientdict: The dictionary from the client root level, not the
|
||||
'auth' key: this method prompts for auth if none is sent.
|
||||
|
||||
clientip (str): The IP address of the client.
|
||||
|
||||
Returns:
|
||||
A tuple of (authed, dict, dict, session_id) where authed is true if
|
||||
the client has successfully completed an auth flow. If it is true
|
||||
the first dict contains the authenticated credentials of each stage.
|
||||
defer.Deferred[dict, dict, str]: a deferred tuple of
|
||||
(creds, params, session_id).
|
||||
|
||||
If authed is false, the first dictionary is the server response to
|
||||
the login request and should be passed back to the client.
|
||||
'creds' contains the authenticated credentials of each stage.
|
||||
|
||||
In either case, the second dict contains the parameters for this
|
||||
request (which may have been given only in a previous call).
|
||||
'params' contains the parameters for this request (which may
|
||||
have been given only in a previous call).
|
||||
|
||||
session_id is the ID of this session, either passed in by the client
|
||||
or assigned by the call to check_auth
|
||||
'session_id' is the ID of this session, either passed in by the
|
||||
client or assigned by this call
|
||||
|
||||
Raises:
|
||||
InteractiveAuthIncompleteError if the client has not yet completed
|
||||
all the stages in any of the permitted flows.
|
||||
"""
|
||||
|
||||
authdict = None
|
||||
|
@ -142,11 +215,8 @@ class AuthHandler(BaseHandler):
|
|||
clientdict = session['clientdict']
|
||||
|
||||
if not authdict:
|
||||
defer.returnValue(
|
||||
(
|
||||
False, self._auth_dict_for_flows(flows, session),
|
||||
clientdict, session['id']
|
||||
)
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session),
|
||||
)
|
||||
|
||||
if 'creds' not in session:
|
||||
|
@ -157,14 +227,12 @@ class AuthHandler(BaseHandler):
|
|||
errordict = {}
|
||||
if 'type' in authdict:
|
||||
login_type = authdict['type']
|
||||
if login_type not in self.checkers:
|
||||
raise LoginError(400, "", Codes.UNRECOGNIZED)
|
||||
try:
|
||||
result = yield self.checkers[login_type](authdict, clientip)
|
||||
result = yield self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
creds[login_type] = result
|
||||
self._save_session(session)
|
||||
except LoginError, e:
|
||||
except LoginError as e:
|
||||
if login_type == LoginType.EMAIL_IDENTITY:
|
||||
# riot used to have a bug where it would request a new
|
||||
# validation token (thus sending a new email) each time it
|
||||
|
@ -173,7 +241,7 @@ class AuthHandler(BaseHandler):
|
|||
#
|
||||
# Grandfather in the old behaviour for now to avoid
|
||||
# breaking old riot deployments.
|
||||
raise e
|
||||
raise
|
||||
|
||||
# this step failed. Merge the error dict into the response
|
||||
# so that the client can have another go.
|
||||
|
@ -190,12 +258,14 @@ class AuthHandler(BaseHandler):
|
|||
"Auth completed with creds: %r. Client dict has keys: %r",
|
||||
creds, clientdict.keys()
|
||||
)
|
||||
defer.returnValue((True, creds, clientdict, session['id']))
|
||||
defer.returnValue((creds, clientdict, session['id']))
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session)
|
||||
ret['completed'] = creds.keys()
|
||||
ret.update(errordict)
|
||||
defer.returnValue((False, ret, clientdict, session['id']))
|
||||
raise InteractiveAuthIncompleteError(
|
||||
ret,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||
|
@ -268,17 +338,35 @@ class AuthHandler(BaseHandler):
|
|||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
def _check_auth_dict(self, authdict, clientip):
|
||||
"""Attempt to validate the auth dict provided by a client
|
||||
|
||||
user_id = authdict["user"]
|
||||
password = authdict["password"]
|
||||
Args:
|
||||
authdict (object): auth dict provided by the client
|
||||
clientip (str): IP address of the client
|
||||
|
||||
(canonical_id, callback) = yield self.validate_login(user_id, {
|
||||
"type": LoginType.PASSWORD,
|
||||
"password": password,
|
||||
})
|
||||
Returns:
|
||||
Deferred: result of the stage verification.
|
||||
|
||||
Raises:
|
||||
StoreError if there was a problem accessing the database
|
||||
SynapseError if there was a problem with the request
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
login_type = authdict['type']
|
||||
checker = self.checkers.get(login_type)
|
||||
if checker is not None:
|
||||
res = yield checker(authdict, clientip)
|
||||
defer.returnValue(res)
|
||||
|
||||
# build a v1-login-style dict out of the authdict and fall back to the
|
||||
# v1 code
|
||||
user_id = authdict.get("user")
|
||||
|
||||
if user_id is None:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||
defer.returnValue(canonical_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -649,41 +737,6 @@ class AuthHandler(BaseHandler):
|
|||
except Exception:
|
||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self.hash(newpassword)
|
||||
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.delete_access_tokens_for_user(
|
||||
user_id, except_token_id=except_access_token_id,
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deactivate_account(self, user_id):
|
||||
"""Deactivate a user's account
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user to be deactivated
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# FIXME: Theoretically there is a race here wherein user resets
|
||||
# password using threepid.
|
||||
yield self.delete_access_tokens_for_user(user_id)
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_token(self, access_token):
|
||||
"""Invalidate a single access token
|
||||
|
@ -706,6 +759,12 @@ class AuthHandler(BaseHandler):
|
|||
access_token=access_token,
|
||||
)
|
||||
|
||||
# delete pushers associated with this access token
|
||||
if user_info["token_id"] is not None:
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
str(user_info["user"]), (user_info["token_id"], )
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
|
||||
device_id=None):
|
||||
|
@ -728,13 +787,18 @@ class AuthHandler(BaseHandler):
|
|||
# see if any of our auth providers want to know about this
|
||||
for provider in self.password_providers:
|
||||
if hasattr(provider, "on_logged_out"):
|
||||
for token, device_id in tokens_and_devices:
|
||||
for token, token_id, device_id in tokens_and_devices:
|
||||
yield provider.on_logged_out(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
access_token=token,
|
||||
)
|
||||
|
||||
# delete pushers associated with the access tokens
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||
user_id, (token_id for _, token_id, _ in tokens_and_devices),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_threepid(self, user_id, medium, address, validated_at):
|
||||
# 'Canonicalise' email addresses down to lower case.
|
||||
|
|
52
synapse/handlers/deactivate_account.py
Normal file
52
synapse/handlers/deactivate_account.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeactivateAccountHandler(BaseHandler):
|
||||
"""Handler which deals with deactivating user accounts."""
|
||||
def __init__(self, hs):
|
||||
super(DeactivateAccountHandler, self).__init__(hs)
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deactivate_account(self, user_id):
|
||||
"""Deactivate a user's account
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user to be deactivated
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# FIXME: Theoretically there is a race here wherein user resets
|
||||
# password using threepid.
|
||||
|
||||
# first delete any devices belonging to the user, which will also
|
||||
# delete corresponding access tokens.
|
||||
yield self._device_handler.delete_all_devices_for_user(user_id)
|
||||
# then delete any remaining access tokens which weren't associated with
|
||||
# a device.
|
||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
|
@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
|
|||
|
||||
yield self.notify_device_update(user_id, [device_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_all_devices_for_user(self, user_id, except_device_id=None):
|
||||
"""Delete all of the user's devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
except_device_id (str|None): optional device id which should not
|
||||
be deleted
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
device_ids = device_map.keys()
|
||||
if except_device_id is not None:
|
||||
device_ids = [d for d in device_ids if d != except_device_id]
|
||||
yield self.delete_devices(user_id, device_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_devices(self, user_id, device_ids):
|
||||
""" Delete several devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_ids (str): The list of device IDs to delete
|
||||
device_ids (List[str]): The list of device IDs to delete
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
|
|
|
@ -375,6 +375,12 @@ class GroupsLocalHandler(object):
|
|||
def get_publicised_groups_for_user(self, user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||
|
||||
# Check AS associated groups for this user - this depends on the
|
||||
# RegExps in the AS registration file (under `users`)
|
||||
for app_service in self.store.get_app_services():
|
||||
result.extend(app_service.get_groups_for_user(user_id))
|
||||
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
|
@ -415,4 +421,9 @@ class GroupsLocalHandler(object):
|
|||
uid
|
||||
)
|
||||
|
||||
# Check AS associated groups for this user - this depends on the
|
||||
# RegExps in the AS registration file (under `users`)
|
||||
for app_service in self.store.get_app_services():
|
||||
results[uid].extend(app_service.get_groups_for_user(uid))
|
||||
|
||||
defer.returnValue({"users": results})
|
||||
|
|
|
@ -27,7 +27,7 @@ from synapse.types import (
|
|||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
@ -163,7 +163,7 @@ class InitialSyncHandler(BaseHandler):
|
|||
lambda states: states[event.event_id]
|
||||
)
|
||||
|
||||
(messages, token), current_state = yield preserve_context_over_deferred(
|
||||
(messages, token), current_state = yield make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.get_recent_events_for_room)(
|
||||
|
|
|
@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
|
|||
)
|
||||
changed = True
|
||||
else:
|
||||
# We expect to be poked occaisonally by the other side.
|
||||
# We expect to be poked occasionally by the other side.
|
||||
# This is to protect against forgetful/buggy servers, so that
|
||||
# no one gets stuck online forever.
|
||||
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
|
||||
|
|
|
@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
|
|||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
|
||||
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
|
|||
target_user.localpart, new_displayname
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
yield self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
|
|||
target_user.localpart, new_avatar_url
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
yield self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler):
|
|||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
|
@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler):
|
|||
),
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(localpart)
|
||||
yield self.user_directory_handler.handle_local_profile_change(
|
||||
user_id, profile
|
||||
)
|
||||
|
||||
else:
|
||||
# autogen a sequential user ID
|
||||
attempts = 0
|
||||
|
|
|
@ -154,6 +154,8 @@ class RoomListHandler(BaseHandler):
|
|||
# We want larger rooms to be first, hence negating num_joined_users
|
||||
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
|
||||
|
||||
logger.info("Getting ordering for %i rooms since %s",
|
||||
len(room_ids), stream_token)
|
||||
yield concurrently_execute(get_order_for_room, room_ids, 10)
|
||||
|
||||
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
|
||||
|
@ -181,34 +183,42 @@ class RoomListHandler(BaseHandler):
|
|||
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
||||
rooms_to_scan.reverse()
|
||||
|
||||
# Actually generate the entries. _append_room_entry_to_chunk will append to
|
||||
# chunk but will stop if len(chunk) > limit
|
||||
chunk = []
|
||||
if limit and not search_filter:
|
||||
logger.info("After sorting and filtering, %i rooms remain",
|
||||
len(rooms_to_scan))
|
||||
|
||||
# _append_room_entry_to_chunk will append to chunk but will stop if
|
||||
# len(chunk) > limit
|
||||
#
|
||||
# Normally we will generate enough results on the first iteration here,
|
||||
# but if there is a search filter, _append_room_entry_to_chunk may
|
||||
# filter some results out, in which case we loop again.
|
||||
#
|
||||
# We don't want to scan over the entire range either as that
|
||||
# would potentially waste a lot of work.
|
||||
#
|
||||
# XXX if there is no limit, we may end up DoSing the server with
|
||||
# calls to get_current_state_ids for every single room on the
|
||||
# server. Surely we should cap this somehow?
|
||||
#
|
||||
if limit:
|
||||
step = limit + 1
|
||||
for i in xrange(0, len(rooms_to_scan), step):
|
||||
# We iterate here because the vast majority of cases we'll stop
|
||||
# at first iteration, but occaisonally _append_room_entry_to_chunk
|
||||
# won't append to the chunk and so we need to loop again.
|
||||
# We don't want to scan over the entire range either as that
|
||||
# would potentially waste a lot of work.
|
||||
yield concurrently_execute(
|
||||
lambda r: self._append_room_entry_to_chunk(
|
||||
r, rooms_to_num_joined[r],
|
||||
chunk, limit, search_filter
|
||||
),
|
||||
rooms_to_scan[i:i + step], 10
|
||||
)
|
||||
if len(chunk) >= limit + 1:
|
||||
break
|
||||
else:
|
||||
step = len(rooms_to_scan)
|
||||
|
||||
chunk = []
|
||||
for i in xrange(0, len(rooms_to_scan), step):
|
||||
batch = rooms_to_scan[i:i + step]
|
||||
logger.info("Processing %i rooms for result", len(batch))
|
||||
yield concurrently_execute(
|
||||
lambda r: self._append_room_entry_to_chunk(
|
||||
r, rooms_to_num_joined[r],
|
||||
chunk, limit, search_filter
|
||||
),
|
||||
rooms_to_scan, 5
|
||||
batch, 5,
|
||||
)
|
||||
logger.info("Now %i rooms in result", len(chunk))
|
||||
if len(chunk) >= limit + 1:
|
||||
break
|
||||
|
||||
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
|
||||
|
||||
|
|
56
synapse/handlers/set_password.py
Normal file
56
synapse/handlers/set_password.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SetPasswordHandler(BaseHandler):
|
||||
"""Handler which deals with changing user account passwords"""
|
||||
def __init__(self, hs):
|
||||
super(SetPasswordHandler, self).__init__(hs)
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = self._auth_handler.hash(newpassword)
|
||||
|
||||
except_device_id = requester.device_id if requester else None
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
|
||||
# we want to log out all of the user's other sessions. First delete
|
||||
# all his other devices.
|
||||
yield self._device_handler.delete_all_devices_for_user(
|
||||
user_id, except_device_id=except_device_id,
|
||||
)
|
||||
|
||||
# and now delete any access tokens which weren't associated with
|
||||
# devices (or were associated with this device).
|
||||
yield self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, except_token_id=except_access_token_id,
|
||||
)
|
|
@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
|
|||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.async import sleep
|
||||
from synapse.types import get_localpart_from_id
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDirectoyHandler(object):
|
||||
class UserDirectoryHandler(object):
|
||||
"""Handles querying of and keeping updated the user_directory.
|
||||
|
||||
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
|
||||
|
@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
|
|||
one public room.
|
||||
"""
|
||||
|
||||
INITIAL_SLEEP_MS = 50
|
||||
INITIAL_SLEEP_COUNT = 100
|
||||
INITIAL_BATCH_SIZE = 100
|
||||
INITIAL_ROOM_SLEEP_MS = 50
|
||||
INITIAL_ROOM_SLEEP_COUNT = 100
|
||||
INITIAL_ROOM_BATCH_SIZE = 100
|
||||
INITIAL_USER_SLEEP_MS = 10
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
|
|||
self.notifier = hs.get_notifier()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.update_user_directory = hs.config.update_user_directory
|
||||
self.search_all_users = hs.config.user_directory_search_all_users
|
||||
|
||||
# When start up for the first time we need to populate the user_directory.
|
||||
# This is a set of user_id's we've inserted already
|
||||
|
@ -110,6 +113,15 @@ class UserDirectoyHandler(object):
|
|||
finally:
|
||||
self._is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_local_profile_change(self, user_id, profile):
|
||||
"""Called to update index of our local user profiles when they change
|
||||
irrespective of any rooms the user may be in.
|
||||
"""
|
||||
yield self.store.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url, None,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _unsafe_process(self):
|
||||
# If self.pos is None then means we haven't fetched it from DB
|
||||
|
@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
|
|||
room_ids = yield self.store.get_all_rooms()
|
||||
|
||||
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
|
||||
num_processed_rooms = 1
|
||||
num_processed_rooms = 0
|
||||
|
||||
for room_id in room_ids:
|
||||
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
||||
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
|
||||
yield self._handle_initial_room(room_id)
|
||||
num_processed_rooms += 1
|
||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
|
||||
logger.info("Processed all rooms.")
|
||||
|
||||
if self.search_all_users:
|
||||
num_processed_users = 0
|
||||
user_ids = yield self.store.get_all_local_users()
|
||||
logger.info("Doing initial update of user directory. %d users", len(user_ids))
|
||||
for user_id in user_ids:
|
||||
# We add profiles for all users even if they don't match the
|
||||
# include pattern, just in case we want to change it in future
|
||||
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
|
||||
yield self._handle_local_user(user_id)
|
||||
num_processed_users += 1
|
||||
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
||||
|
||||
logger.info("Processed all users")
|
||||
|
||||
self.initially_handled_users = None
|
||||
self.initially_handled_users_in_public = None
|
||||
self.initially_handled_users_share = None
|
||||
|
@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
|
|||
to_update = set()
|
||||
count = 0
|
||||
for user_id in user_ids:
|
||||
if count % self.INITIAL_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
count += 1
|
||||
|
@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
|
|||
if user_id == other_user_id:
|
||||
continue
|
||||
|
||||
if count % self.INITIAL_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
||||
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||
count += 1
|
||||
|
||||
user_set = (user_id, other_user_id)
|
||||
|
@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
|
|||
else:
|
||||
self.initially_handled_users_share_private_room.add(user_set)
|
||||
|
||||
if len(to_insert) > self.INITIAL_BATCH_SIZE:
|
||||
if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
|
||||
yield self.store.add_users_who_share_room(
|
||||
room_id, not is_public, to_insert,
|
||||
)
|
||||
to_insert.clear()
|
||||
|
||||
if len(to_update) > self.INITIAL_BATCH_SIZE:
|
||||
if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
|
||||
yield self.store.update_users_who_share_room(
|
||||
room_id, not is_public, to_update,
|
||||
)
|
||||
|
@ -384,15 +410,29 @@ class UserDirectoyHandler(object):
|
|||
for user_id in users:
|
||||
yield self._handle_remove_user(room_id, user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_local_user(self, user_id):
|
||||
"""Adds a new local roomless user into the user_directory_search table.
|
||||
Used to populate up the user index when we have an
|
||||
user_directory_search_all_users specified.
|
||||
"""
|
||||
logger.debug("Adding new local user to dir, %r", user_id)
|
||||
|
||||
profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
|
||||
|
||||
row = yield self.store.get_user_in_directory(user_id)
|
||||
if not row:
|
||||
yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_user(self, room_id, user_id, profile):
|
||||
"""Called when we might need to add user to directory
|
||||
|
||||
Args:
|
||||
room_id (str): room_id that user joined or started being public that
|
||||
room_id (str): room_id that user joined or started being public
|
||||
user_id (str)
|
||||
"""
|
||||
logger.debug("Adding user to dir, %r", user_id)
|
||||
logger.debug("Adding new user to dir, %r", user_id)
|
||||
|
||||
row = yield self.store.get_user_in_directory(user_id)
|
||||
if not row:
|
||||
|
@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
|
|||
if not row:
|
||||
yield self.store.add_users_to_public_room(room_id, [user_id])
|
||||
else:
|
||||
logger.debug("Not adding user to public dir, %r", user_id)
|
||||
logger.debug("Not adding new user to public dir, %r", user_id)
|
||||
|
||||
# Now we update users who share rooms with users. We do this by getting
|
||||
# all the current users in the room and seeing which aren't already
|
||||
|
|
|
@ -362,8 +362,10 @@ def _get_hosts_for_srv_record(dns_client, host):
|
|||
return res
|
||||
|
||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
||||
d1 = dns_client.lookupAddress(host).addCallbacks(
|
||||
cb, eb, errbackArgs=("A", ))
|
||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
|
||||
cb, eb, errbackArgs=("AAAA", ))
|
||||
results = yield defer.DeferredList(
|
||||
[d1, d2], consumeErrors=True)
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ from canonicaljson import (
|
|||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure
|
||||
from twisted.web import server, resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.web.util import redirectTo
|
||||
|
@ -131,12 +132,17 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
|||
version_string=self.version_string,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r: %r",
|
||||
# failure.Failure() fishes the original Failure out
|
||||
# of our stack, and thus gives us a sensible stack
|
||||
# trace.
|
||||
f = failure.Failure()
|
||||
logger.error(
|
||||
"Failed handle request %s.%s on %r: %r: %s",
|
||||
request_handler.__module__,
|
||||
request_handler.__name__,
|
||||
self,
|
||||
request
|
||||
request,
|
||||
f.getTraceback().rstrip(),
|
||||
)
|
||||
respond_with_json(
|
||||
request,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
|
||||
|
@ -81,6 +82,7 @@ class ModuleApi(object):
|
|||
reg = self.hs.get_handlers().registration_handler
|
||||
return reg.register(localpart=localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invalidate_access_token(self, access_token):
|
||||
"""Invalidate an access token for a user
|
||||
|
||||
|
@ -94,8 +96,16 @@ class ModuleApi(object):
|
|||
Raises:
|
||||
synapse.api.errors.AuthError: the access token is invalid
|
||||
"""
|
||||
|
||||
return self._auth_handler.delete_access_token(access_token)
|
||||
# see if the access token corresponds to a device
|
||||
user_info = yield self._auth.get_user_by_access_token(access_token)
|
||||
device_id = user_info.get("device_id")
|
||||
user_id = user_info["user"].to_string()
|
||||
if device_id:
|
||||
# delete the device, which will also delete its access tokens
|
||||
yield self.hs.get_device_handler().delete_device(user_id, device_id)
|
||||
else:
|
||||
# no associated device. Just delete the access token.
|
||||
yield self._auth_handler.delete_access_token(access_token)
|
||||
|
||||
def run_db_interaction(self, desc, func, *args, **kwargs):
|
||||
"""Run a function with a database connection
|
||||
|
|
|
@ -255,9 +255,7 @@ class Notifier(object):
|
|||
)
|
||||
|
||||
if self.federation_sender:
|
||||
preserve_fn(self.federation_sender.notify_new_events)(
|
||||
room_stream_id
|
||||
)
|
||||
self.federation_sender.notify_new_events(room_stream_id)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
self._user_joined_room(event.state_key, event.room_id)
|
||||
|
@ -297,8 +295,7 @@ class Notifier(object):
|
|||
def on_new_replication_data(self):
|
||||
"""Used to inform replication listeners that something has happend
|
||||
without waking up any of the normal user event streams"""
|
||||
with PreserveLoggingContext():
|
||||
self.notify_replication()
|
||||
self.notify_replication()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
||||
|
@ -516,8 +513,14 @@ class Notifier(object):
|
|||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||
deferred.callback(None)
|
||||
|
||||
for cb in self.replication_callbacks:
|
||||
preserve_fn(cb)()
|
||||
# the callbacks may well outlast the current request, so we run
|
||||
# them in the sentinel logcontext.
|
||||
#
|
||||
# (ideally it would be up to the callbacks to know if they were
|
||||
# starting off background processes and drop the logcontext
|
||||
# accordingly, but that requires more changes)
|
||||
for cb in self.replication_callbacks:
|
||||
cb()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_replication(self, callback, timeout):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -295,7 +296,7 @@ class HttpPusher(object):
|
|||
if event.type == 'm.room.member':
|
||||
d['notification']['membership'] = event.content['membership']
|
||||
d['notification']['user_is_target'] = event.state_key == self.user_id
|
||||
if not self.hs.config.push_redact_content and 'content' in event:
|
||||
if self.hs.config.push_include_content and 'content' in event:
|
||||
d['notification']['content'] = event.content
|
||||
|
||||
# We no longer send aliases separately, instead, we send the human
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from .pusher import PusherFactory
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
import logging
|
||||
|
@ -103,19 +103,25 @@ class PusherPool:
|
|||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
|
||||
all = yield self.store.get_all_pushers()
|
||||
logger.info(
|
||||
"Removing all pushers for user %s except access tokens id %r",
|
||||
user_id, except_access_token_id
|
||||
)
|
||||
for p in all:
|
||||
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
|
||||
def remove_pushers_by_access_token(self, user_id, access_tokens):
|
||||
"""Remove the pushers for a given user corresponding to a set of
|
||||
access_tokens.
|
||||
|
||||
Args:
|
||||
user_id (str): user to remove pushers for
|
||||
access_tokens (Iterable[int]): access token *ids* to remove pushers
|
||||
for
|
||||
"""
|
||||
tokens = set(access_tokens)
|
||||
for p in (yield self.store.get_pushers_by_user_id(user_id)):
|
||||
if p['access_token'] in tokens:
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
p['app_id'], p['pushkey'], p['user_name']
|
||||
)
|
||||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
yield self.remove_pusher(
|
||||
p['app_id'], p['pushkey'], p['user_name'],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_notifications(self, min_stream_id, max_stream_id):
|
||||
|
@ -136,7 +142,7 @@ class PusherPool:
|
|||
)
|
||||
)
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||
yield make_deferred_yieldable(defer.gatherResults(deferreds))
|
||||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_notifications")
|
||||
|
||||
|
@ -161,7 +167,7 @@ class PusherPool:
|
|||
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
|
||||
)
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
||||
yield make_deferred_yieldable(defer.gatherResults(deferreds))
|
||||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
|
|
|
@ -12,20 +12,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
from synapse.storage.state import StateStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.state import StateGroupReadStore
|
||||
from synapse.storage.stream import StreamStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
import logging
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,7 +37,7 @@ logger = logging.getLogger(__name__)
|
|||
# the method descriptor on the DataStore and chuck them into our class.
|
||||
|
||||
|
||||
class SlavedEventStore(BaseSlavedStore):
|
||||
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||
|
@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
_get_unread_counts_by_pos_txn = (
|
||||
DataStore._get_unread_counts_by_pos_txn.__func__
|
||||
)
|
||||
_get_state_group_for_events = (
|
||||
StateStore.__dict__["_get_state_group_for_events"]
|
||||
)
|
||||
_get_state_group_for_event = (
|
||||
StateStore.__dict__["_get_state_group_for_event"]
|
||||
)
|
||||
_get_state_groups_from_groups = (
|
||||
StateStore.__dict__["_get_state_groups_from_groups"]
|
||||
)
|
||||
_get_state_groups_from_groups_txn = (
|
||||
DataStore._get_state_groups_from_groups_txn.__func__
|
||||
)
|
||||
get_recent_event_ids_for_room = (
|
||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
get_current_state_ids = (
|
||||
StateStore.__dict__["get_current_state_ids"]
|
||||
)
|
||||
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
|
||||
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
|
||||
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
||||
|
||||
|
@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
DataStore.get_room_events_stream_for_room.__func__
|
||||
)
|
||||
get_events_around = DataStore.get_events_around.__func__
|
||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
||||
get_state_groups = DataStore.get_state_groups.__func__
|
||||
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
||||
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
||||
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
||||
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
||||
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
||||
_get_joined_users_from_context = (
|
||||
|
@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore):
|
|||
_get_rooms_for_user_where_membership_is_txn = (
|
||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||
)
|
||||
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
||||
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
||||
|
||||
get_backfill_events = DataStore.get_backfill_events.__func__
|
||||
_get_backfill_events = DataStore._get_backfill_events.__func__
|
||||
|
|
|
@ -216,11 +216,12 @@ class ReplicationStreamer(object):
|
|||
self.federation_sender.federation_ack(token)
|
||||
|
||||
@measure_func("repl.on_user_sync")
|
||||
@defer.inlineCallbacks
|
||||
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
||||
"""A client has started/stopped syncing on a worker.
|
||||
"""
|
||||
user_sync_counter.inc()
|
||||
self.presence_handler.update_external_syncs_row(
|
||||
yield self.presence_handler.update_external_syncs_row(
|
||||
conn_id, user_id, is_syncing, last_sync_ms,
|
||||
)
|
||||
|
||||
|
@ -244,11 +245,12 @@ class ReplicationStreamer(object):
|
|||
getattr(self.store, cache_func).invalidate(tuple(keys))
|
||||
|
||||
@measure_func("repl.on_user_ip")
|
||||
@defer.inlineCallbacks
|
||||
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||
"""The client saw a user request
|
||||
"""
|
||||
user_ip_cache_counter.inc()
|
||||
self.store.insert_client_ip(
|
||||
yield self.store.insert_client_ip(
|
||||
user_id, access_token, ip, user_agent, device_id, last_seen,
|
||||
)
|
||||
|
||||
|
|
|
@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, target_user_id):
|
||||
|
@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
|||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self._auth_handler.deactivate_account(target_user_id)
|
||||
yield self._deactivate_account_handler.deactivate_account(target_user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
|||
super(ResetPasswordRestServlet, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self._set_password_handler = hs.get_set_password_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, target_user_id):
|
||||
|
@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
|||
|
||||
logger.info("new_password: %r", new_password)
|
||||
|
||||
yield self.auth_handler.set_password(
|
||||
yield self._set_password_handler.set_password(
|
||||
target_user_id, new_password, requester
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import get_access_token_from_request
|
||||
from synapse.api.errors import AuthError
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
|
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(LogoutRestServlet, self).__init__(hs)
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
access_token = get_access_token_from_request(request)
|
||||
yield self._auth_handler.delete_access_token(access_token)
|
||||
try:
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
except AuthError:
|
||||
# this implies the access token has already been deleted.
|
||||
pass
|
||||
else:
|
||||
if requester.device_id is None:
|
||||
# the acccess token wasn't associated with a device.
|
||||
# Just delete the access token
|
||||
access_token = get_access_token_from_request(request)
|
||||
yield self._auth_handler.delete_access_token(access_token)
|
||||
else:
|
||||
yield self._device_handler.delete_device(
|
||||
requester.user.to_string(), requester.device_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||
super(LogoutAllRestServlet, self).__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
|||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# first delete all of the user's devices
|
||||
yield self._device_handler.delete_all_devices_for_user(user_id)
|
||||
|
||||
# .. and then delete any access tokens which weren't associated with
|
||||
# devices.
|
||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
|
||||
"""This module contains base REST classes for constructing client v1 servlets.
|
||||
"""
|
||||
|
||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||
import logging
|
||||
import re
|
||||
|
||||
import logging
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
|
|||
filter_json['room']['timeline']["limit"] = min(
|
||||
filter_json['room']['timeline']['limit'],
|
||||
filter_timeline_limit)
|
||||
|
||||
|
||||
def interactive_auth_handler(orig):
|
||||
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||
|
||||
Takes a on_POST method which returns a deferred (errcode, body) response
|
||||
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||
a 401 response.
|
||||
|
||||
Normal usage is:
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
# ...
|
||||
yield self.auth_handler.check_auth
|
||||
"""
|
||||
def wrapped(*args, **kwargs):
|
||||
res = defer.maybeDeferred(orig, *args, **kwargs)
|
||||
res.addErrback(_catch_incomplete_interactive_auth)
|
||||
return res
|
||||
return wrapped
|
||||
|
||||
|
||||
def _catch_incomplete_interactive_auth(f):
|
||||
"""helper for interactive_auth_handler
|
||||
|
||||
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
||||
|
||||
Args:
|
||||
f (failure.Failure):
|
||||
"""
|
||||
f.trap(InteractiveAuthIncompleteError)
|
||||
return 401, f.value.result
|
||||
|
|
|
@ -19,14 +19,14 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.auth import has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, assert_params_in_request,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from ._base import client_v2_patterns
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -98,56 +98,61 @@ class PasswordRestServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.datastore = self.hs.get_datastore()
|
||||
self._set_password_handler = hs.get_set_password_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
[LoginType.EMAIL_IDENTITY],
|
||||
[LoginType.MSISDN],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
# there are two possibilities here. Either the user does not have an
|
||||
# access token, and needs to do a password reset; or they have one and
|
||||
# need to validate their identity.
|
||||
#
|
||||
# In the first case, we offer a couple of means of identifying
|
||||
# themselves (email and msisdn, though it's unclear if msisdn actually
|
||||
# works).
|
||||
#
|
||||
# In the second case, we require a password to confirm their identity.
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
requester = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
if has_access_token(request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
if user_id != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
elif LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
if threepid['medium'] == 'email':
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
threepid['address'] = threepid['address'].lower()
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
params = yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
if not threepid_user_id:
|
||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
user_id = threepid_user_id
|
||||
user_id = requester.user.to_string()
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
requester = None
|
||||
result, params, _ = yield self.auth_handler.check_auth(
|
||||
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
|
||||
body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
if LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
if threepid['medium'] == 'email':
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
threepid['address'] = threepid['address'].lower()
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
)
|
||||
if not threepid_user_id:
|
||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
user_id = threepid_user_id
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
if 'new_password' not in params:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
new_password = params['new_password']
|
||||
|
||||
yield self.auth_handler.set_password(
|
||||
yield self._set_password_handler.set_password(
|
||||
user_id, new_password, requester
|
||||
)
|
||||
|
||||
|
@ -161,52 +166,32 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
PATTERNS = client_v2_patterns("/account/deactivate$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DeactivateAccountRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__()
|
||||
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# if the caller provides an access token, it ought to be valid.
|
||||
requester = None
|
||||
if has_access_token(request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
) # type: synapse.types.Requester
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# allow ASes to dectivate their own users
|
||||
if requester and requester.app_service:
|
||||
yield self.auth_handler.deactivate_account(
|
||||
if requester.app_service:
|
||||
yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
user_id = result[LoginType.PASSWORD]
|
||||
# if using password, they should also be logged in
|
||||
if requester is None:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Deactivate account requires an access_token",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
if requester.user.to_string() != user_id:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
yield self.auth_handler.deactivate_account(user_id)
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
|
|
@ -17,9 +17,9 @@ import logging
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api import constants, errors
|
||||
from synapse.api import errors
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
self.device_handler = hs.get_device_handler()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
except errors.SynapseError as e:
|
||||
|
@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
|||
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[constants.LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_devices(
|
||||
requester.user.to_string(),
|
||||
body['devices'],
|
||||
|
@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
|
|||
else:
|
||||
raise
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[constants.LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
# check that the UI auth matched the access token
|
||||
user_id = result[constants.LoginType.PASSWORD]
|
||||
if user_id != requester.user.to_string():
|
||||
raise errors.AuthError(403, "Invalid auth")
|
||||
|
||||
yield self.device_handler.delete_device(user_id, device_id)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(), device_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
@ -38,7 +38,7 @@ class GroupServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
group_description = yield self.groups_handler.get_group_profile(
|
||||
|
@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(
|
||||
|
@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_category(
|
||||
|
@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_categories(
|
||||
|
@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_role(
|
||||
|
@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_roles(
|
||||
|
@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
||||
|
@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
||||
|
@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||
user_id
|
||||
|
@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
user_ids = content["user_ids"]
|
||||
|
@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_joined_groups(requester_user_id)
|
||||
|
|
|
@ -27,7 +27,7 @@ from synapse.http.servlet import (
|
|||
)
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
import logging
|
||||
import hmac
|
||||
|
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
|
|||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
@interactive_auth_handler
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
|
|||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||
])
|
||||
|
||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, auth_result))
|
||||
return
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
|
|
|
@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
|
|||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
"r0.3.0",
|
||||
]
|
||||
})
|
||||
|
||||
|
|
|
@ -25,7 +25,8 @@ from synapse.util.stringutils import random_string
|
|||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.http.client import SpiderHttpClient
|
||||
from synapse.http.server import (
|
||||
request_handler, respond_with_json_bytes
|
||||
request_handler, respond_with_json_bytes,
|
||||
respond_with_json,
|
||||
)
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
@ -78,6 +79,9 @@ class PreviewUrlResource(Resource):
|
|||
self._expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
return respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
@ -348,11 +352,16 @@ class PreviewUrlResource(Resource):
|
|||
def _expire_url_cache_data(self):
|
||||
"""Clean up expired url cache content, media and thumbnails.
|
||||
"""
|
||||
|
||||
# TODO: Delete from backup media store
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
logger.info("Running url preview cache expiry")
|
||||
|
||||
if not (yield self.store.has_completed_background_updates()):
|
||||
logger.info("Still running DB updates; skipping expiry")
|
||||
return
|
||||
|
||||
# First we delete expired url cache entries
|
||||
media_ids = yield self.store.get_expired_url_cache(now)
|
||||
|
||||
|
@ -426,8 +435,7 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
yield self.store.delete_url_cache_media(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
|
|
|
@ -39,18 +39,20 @@ from synapse.federation.transaction_queue import TransactionQueue
|
|||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.room_list import RoomListHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
|
@ -60,7 +62,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
|||
from synapse.notifier import Notifier
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
from synapse.rest.media.v1.media_repository import (
|
||||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
)
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.streams.events import EventSources
|
||||
|
@ -90,17 +95,12 @@ class HomeServer(object):
|
|||
"""
|
||||
|
||||
DEPENDENCIES = [
|
||||
'config',
|
||||
'clock',
|
||||
'http_client',
|
||||
'db_pool',
|
||||
'persistence_service',
|
||||
'replication_layer',
|
||||
'datastore',
|
||||
'handlers',
|
||||
'v1auth',
|
||||
'auth',
|
||||
'rest_servlet_factory',
|
||||
'state_handler',
|
||||
'presence_handler',
|
||||
'sync_handler',
|
||||
|
@ -117,19 +117,10 @@ class HomeServer(object):
|
|||
'application_service_handler',
|
||||
'device_message_handler',
|
||||
'profile_handler',
|
||||
'deactivate_account_handler',
|
||||
'set_password_handler',
|
||||
'notifier',
|
||||
'distributor',
|
||||
'client_resource',
|
||||
'resource_for_federation',
|
||||
'resource_for_static_content',
|
||||
'resource_for_web_client',
|
||||
'resource_for_content_repo',
|
||||
'resource_for_server_key',
|
||||
'resource_for_server_key_v2',
|
||||
'resource_for_media_repository',
|
||||
'resource_for_metrics',
|
||||
'event_sources',
|
||||
'ratelimiter',
|
||||
'keyring',
|
||||
'pusherpool',
|
||||
'event_builder_factory',
|
||||
|
@ -137,6 +128,7 @@ class HomeServer(object):
|
|||
'http_client_context_factory',
|
||||
'simple_http_client',
|
||||
'media_repository',
|
||||
'media_repository_resource',
|
||||
'federation_transport_client',
|
||||
'federation_sender',
|
||||
'receipts_handler',
|
||||
|
@ -183,6 +175,21 @@ class HomeServer(object):
|
|||
def is_mine_id(self, string):
|
||||
return string.split(":", 1)[1] == self.hostname
|
||||
|
||||
def get_clock(self):
|
||||
return self.clock
|
||||
|
||||
def get_datastore(self):
|
||||
return self.datastore
|
||||
|
||||
def get_config(self):
|
||||
return self.config
|
||||
|
||||
def get_distributor(self):
|
||||
return self.distributor
|
||||
|
||||
def get_ratelimiter(self):
|
||||
return self.ratelimiter
|
||||
|
||||
def build_replication_layer(self):
|
||||
return initialize_http_replication(self)
|
||||
|
||||
|
@ -265,6 +272,12 @@ class HomeServer(object):
|
|||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
|
||||
def build_deactivate_account_handler(self):
|
||||
return DeactivateAccountHandler(self)
|
||||
|
||||
def build_set_password_handler(self):
|
||||
return SetPasswordHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
|
@ -294,6 +307,11 @@ class HomeServer(object):
|
|||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def build_media_repository_resource(self):
|
||||
# build the media repo resource. This indirects through the HomeServer
|
||||
# to ensure that we only have a single instance of
|
||||
return MediaRepositoryResource(self)
|
||||
|
||||
def build_media_repository(self):
|
||||
return MediaRepository(self)
|
||||
|
||||
|
@ -321,7 +339,7 @@ class HomeServer(object):
|
|||
return ActionGenerator(self)
|
||||
|
||||
def build_user_directory_handler(self):
|
||||
return UserDirectoyHandler(self)
|
||||
return UserDirectoryHandler(self)
|
||||
|
||||
def build_groups_local_handler(self):
|
||||
return GroupsLocalHandler(self)
|
||||
|
|
|
@ -3,10 +3,14 @@ import synapse.federation.transaction_queue
|
|||
import synapse.federation.transport.client
|
||||
import synapse.handlers
|
||||
import synapse.handlers.auth
|
||||
import synapse.handlers.deactivate_account
|
||||
import synapse.handlers.device
|
||||
import synapse.handlers.e2e_keys
|
||||
import synapse.storage
|
||||
import synapse.handlers.set_password
|
||||
import synapse.rest.media.v1.media_repository
|
||||
import synapse.state
|
||||
import synapse.storage
|
||||
|
||||
|
||||
class HomeServer(object):
|
||||
def get_auth(self) -> synapse.api.auth.Auth:
|
||||
|
@ -30,8 +34,20 @@ class HomeServer(object):
|
|||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
|
||||
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||
pass
|
||||
|
||||
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
|
||||
pass
|
||||
|
||||
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
||||
pass
|
||||
|
||||
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
||||
pass
|
||||
|
||||
def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
|
||||
pass
|
||||
|
||||
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
|
||||
pass
|
||||
|
|
|
@ -16,8 +16,6 @@ import logging
|
|||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.caches.descriptors import Cache
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
import synapse.metrics
|
||||
|
@ -180,10 +178,6 @@ class SQLBaseStore(object):
|
|||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
|
||||
self._state_group_cache = DictionaryCache(
|
||||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
self._event_fetch_ongoing = 0
|
||||
|
@ -475,23 +469,53 @@ class SQLBaseStore(object):
|
|||
|
||||
txn.executemany(sql, vals)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_upsert(self, table, keyvalues, values,
|
||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
||||
"""
|
||||
|
||||
`lock` should generally be set to True (the default), but can be set
|
||||
to False if either of the following are true:
|
||||
|
||||
* there is a UNIQUE INDEX on the key columns. In this case a conflict
|
||||
will cause an IntegrityError in which case this function will retry
|
||||
the update.
|
||||
|
||||
* we somehow know that we are the only thread which will be updating
|
||||
this table.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): key/values to use when inserting
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
lock (bool): True to lock the table when doing the upsert.
|
||||
Returns:
|
||||
Deferred(bool): True if a new entry was created, False if an
|
||||
existing one was updated.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
||||
lock
|
||||
)
|
||||
attempts = 0
|
||||
while True:
|
||||
try:
|
||||
result = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
||||
lock=lock
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except self.database_engine.module.IntegrityError as e:
|
||||
attempts += 1
|
||||
if attempts >= 5:
|
||||
# don't retry forever, because things other than races
|
||||
# can cause IntegrityErrors
|
||||
raise
|
||||
|
||||
# presumably we raced with another transaction: let's retry.
|
||||
logger.warn(
|
||||
"IntegrityError when upserting into %s; retrying: %s",
|
||||
table, e
|
||||
)
|
||||
|
||||
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
||||
lock=True):
|
||||
|
@ -499,7 +523,7 @@ class SQLBaseStore(object):
|
|||
if lock:
|
||||
self.database_engine.lock_table(txn, table)
|
||||
|
||||
# Try to update
|
||||
# First try to update.
|
||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
|
@ -508,28 +532,29 @@ class SQLBaseStore(object):
|
|||
sqlargs = values.values() + keyvalues.values()
|
||||
|
||||
txn.execute(sql, sqlargs)
|
||||
if txn.rowcount == 0:
|
||||
# We didn't update and rows so insert a new one
|
||||
allvalues = {}
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues)
|
||||
)
|
||||
txn.execute(sql, allvalues.values())
|
||||
|
||||
return True
|
||||
else:
|
||||
if txn.rowcount > 0:
|
||||
# successfully updated at least one row.
|
||||
return False
|
||||
|
||||
# We didn't update any rows so insert a new one
|
||||
allvalues = {}
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues)
|
||||
)
|
||||
txn.execute(sql, allvalues.values())
|
||||
# successfully inserted
|
||||
return True
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
|
@ -582,20 +607,18 @@ class SQLBaseStore(object):
|
|||
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s %(where)s"
|
||||
"SELECT %(retcol)s FROM %(table)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
"where": where,
|
||||
}
|
||||
|
||||
txn.execute(sql, keyvalues.values())
|
||||
if keyvalues:
|
||||
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
txn.execute(sql, keyvalues.values())
|
||||
else:
|
||||
txn.execute(sql)
|
||||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
|
@ -606,7 +629,7 @@ class SQLBaseStore(object):
|
|||
|
||||
Args:
|
||||
table (str): table name
|
||||
keyvalues (dict): column names and values to select the rows with
|
||||
keyvalues (dict|None): column names and values to select the rows with
|
||||
retcol (str): column whos value we wish to retrieve.
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore):
|
|||
"""
|
||||
content_json = json.dumps(content)
|
||||
|
||||
def add_account_data_txn(txn, next_id):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so _simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
yield self._simple_upsert(
|
||||
desc="add_room_account_data",
|
||||
table="room_account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
|
@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore):
|
|||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
}
|
||||
},
|
||||
lock=False,
|
||||
)
|
||||
txn.call_after(
|
||||
self._account_data_stream_cache.entity_has_changed,
|
||||
user_id, next_id,
|
||||
)
|
||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
||||
self._update_max_stream_id(txn, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"add_room_account_data", add_account_data_txn, next_id
|
||||
)
|
||||
# it's theoretically possible for the above to succeed and the
|
||||
# below to fail - in which case we might reuse a stream id on
|
||||
# restart, and the above update might not get propagated. That
|
||||
# doesn't sound any worse than the whole update getting lost,
|
||||
# which is what would happen if we combined the two into one
|
||||
# transaction.
|
||||
yield self._update_max_stream_id(next_id)
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
defer.returnValue(result)
|
||||
|
@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore):
|
|||
"""
|
||||
content_json = json.dumps(content)
|
||||
|
||||
def add_account_data_txn(txn, next_id):
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so _simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
yield self._simple_upsert(
|
||||
desc="add_user_account_data",
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
|
@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore):
|
|||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
}
|
||||
},
|
||||
lock=False,
|
||||
)
|
||||
txn.call_after(
|
||||
self._account_data_stream_cache.entity_has_changed,
|
||||
|
||||
# it's theoretically possible for the above to succeed and the
|
||||
# below to fail - in which case we might reuse a stream id on
|
||||
# restart, and the above update might not get propagated. That
|
||||
# doesn't sound any worse than the whole update getting lost,
|
||||
# which is what would happen if we combined the two into one
|
||||
# transaction.
|
||||
yield self._update_max_stream_id(next_id)
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, next_id,
|
||||
)
|
||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
||||
txn.call_after(
|
||||
self.get_global_account_data_by_type_for_user.invalidate,
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(account_data_type, user_id,)
|
||||
)
|
||||
self._update_max_stream_id(txn, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
yield self.runInteraction(
|
||||
"add_user_account_data", add_account_data_txn, next_id
|
||||
)
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
defer.returnValue(result)
|
||||
|
||||
def _update_max_stream_id(self, txn, next_id):
|
||||
def _update_max_stream_id(self, next_id):
|
||||
"""Update the max stream_id
|
||||
|
||||
Args:
|
||||
txn: The database cursor
|
||||
next_id(int): The the revision to advance to.
|
||||
"""
|
||||
update_max_id_sql = (
|
||||
"UPDATE account_data_max_stream_id"
|
||||
" SET stream_id = ?"
|
||||
" WHERE stream_id < ?"
|
||||
def _update(txn):
|
||||
update_max_id_sql = (
|
||||
"UPDATE account_data_max_stream_id"
|
||||
" SET stream_id = ?"
|
||||
" WHERE stream_id < ?"
|
||||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
return self.runInteraction(
|
||||
"update_account_data_max_stream_id",
|
||||
_update,
|
||||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
||||
|
|
|
@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
self._background_update_performance = {}
|
||||
self._background_update_queue = []
|
||||
self._background_update_handlers = {}
|
||||
self._all_done = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_doing_background_updates(self):
|
||||
|
@ -106,8 +107,40 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
"No more background updates to do."
|
||||
" Unscheduling background update task."
|
||||
)
|
||||
self._all_done = True
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def has_completed_background_updates(self):
|
||||
"""Check if all the background updates have completed
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if all background updates have completed
|
||||
"""
|
||||
# if we've previously determined that there is nothing left to do, that
|
||||
# is easy
|
||||
if self._all_done:
|
||||
defer.returnValue(True)
|
||||
|
||||
# obviously, if we have things in our queue, we're not done.
|
||||
if self._background_update_queue:
|
||||
defer.returnValue(False)
|
||||
|
||||
# otherwise, check if there are updates to be run. This is important,
|
||||
# as we may be running on a worker which doesn't perform the bg updates
|
||||
# itself, but still wants to wait for them to happen.
|
||||
updates = yield self._simple_select_onecol(
|
||||
"background_updates",
|
||||
keyvalues=None,
|
||||
retcol="1",
|
||||
desc="check_background_updates",
|
||||
)
|
||||
if not updates:
|
||||
self._all_done = True
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_next_background_update(self, desired_duration_ms):
|
||||
"""Does some amount of work on the next queued background update
|
||||
|
@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
|||
# Sqlite doesn't support concurrent creation of indexes.
|
||||
#
|
||||
# We don't use partial indices on SQLite as it wasn't introduced
|
||||
# until 3.8, and wheezy has 3.7
|
||||
# until 3.8, and wheezy and CentOS 7 have 3.7
|
||||
#
|
||||
# We assume that sqlite doesn't give us invalid indices; however
|
||||
# we may still end up with the index existing but the
|
||||
|
|
|
@ -12,13 +12,23 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
|
||||
|
||||
class MediaRepositoryStore(SQLBaseStore):
|
||||
class MediaRepositoryStore(BackgroundUpdateStore):
|
||||
"""Persistence for attachments and avatars"""
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.register_background_index_update(
|
||||
update_name='local_media_repository_url_idx',
|
||||
index_name='local_media_repository_url_idx',
|
||||
table='local_media_repository',
|
||||
columns=['created_ts'],
|
||||
where_clause='url_cache IS NOT NULL',
|
||||
)
|
||||
|
||||
def get_default_thumbnails(self, top_level_type, sub_type):
|
||||
return []
|
||||
|
||||
|
|
|
@ -15,6 +15,9 @@
|
|||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
|
@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
|
|||
desc="create_profile",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profileinfo(self, user_localpart):
|
||||
try:
|
||||
profile = yield self._simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
defer.returnValue(ProfileInfo(None, None))
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
defer.returnValue(
|
||||
ProfileInfo(
|
||||
avatar_url=profile['avatar_url'],
|
||||
display_name=profile['displayname'],
|
||||
)
|
||||
)
|
||||
|
||||
def get_profile_displayname(self, user_localpart):
|
||||
return self._simple_select_one_onecol(
|
||||
table="profiles",
|
||||
|
|
|
@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore):
|
|||
pushkey, pushkey_ts, lang, data, last_stream_ordering,
|
||||
profile_tag=""):
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
def f(txn):
|
||||
newly_inserted = self._simple_upsert_txn(
|
||||
txn,
|
||||
"pushers",
|
||||
{
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_name": user_id,
|
||||
},
|
||||
{
|
||||
"access_token": access_token,
|
||||
"kind": kind,
|
||||
"app_display_name": app_display_name,
|
||||
"device_display_name": device_display_name,
|
||||
"ts": pushkey_ts,
|
||||
"lang": lang,
|
||||
"data": encode_canonical_json(data),
|
||||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
},
|
||||
)
|
||||
if newly_inserted:
|
||||
# get_if_user_has_pusher only cares if the user has
|
||||
# at least *one* pusher.
|
||||
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
||||
newly_inserted = yield self._simple_upsert(
|
||||
table="pushers",
|
||||
keyvalues={
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_name": user_id,
|
||||
},
|
||||
values={
|
||||
"access_token": access_token,
|
||||
"kind": kind,
|
||||
"app_display_name": app_display_name,
|
||||
"device_display_name": device_display_name,
|
||||
"ts": pushkey_ts,
|
||||
"lang": lang,
|
||||
"data": encode_canonical_json(data),
|
||||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
},
|
||||
desc="add_pusher",
|
||||
lock=False,
|
||||
)
|
||||
|
||||
yield self.runInteraction("add_pusher", f)
|
||||
if newly_inserted:
|
||||
# get_if_user_has_pusher only cares if the user has
|
||||
# at least *one* pusher.
|
||||
self.get_if_user_has_pusher.invalidate(user_id,)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||
|
@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore):
|
|||
"pushers",
|
||||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
|
||||
)
|
||||
self._simple_upsert_txn(
|
||||
|
||||
# it's possible for us to end up with duplicate rows for
|
||||
# (app_id, pushkey, user_id) at different stream_ids, but that
|
||||
# doesn't really matter.
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"deleted_pushers",
|
||||
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
|
||||
{"stream_id": stream_id},
|
||||
table="deleted_pushers",
|
||||
values={
|
||||
"stream_id": stream_id,
|
||||
"app_id": app_id,
|
||||
"pushkey": pushkey,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
|
@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def set_throttle_params(self, pusher_id, room_id, params):
|
||||
# no need to lock because `pusher_throttle` has a primary key on
|
||||
# (pusher, room_id) so _simple_upsert will retry
|
||||
yield self._simple_upsert(
|
||||
"pusher_throttle",
|
||||
{"pusher": pusher_id, "room_id": room_id},
|
||||
params,
|
||||
desc="set_throttle_params"
|
||||
desc="set_throttle_params",
|
||||
lock=False,
|
||||
)
|
||||
|
|
|
@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
Returns:
|
||||
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
|
||||
and device IDs
|
||||
defer.Deferred[list[str, int, str|None, int]]: a list of
|
||||
(token, token id, device id) for each of the deleted tokens
|
||||
"""
|
||||
def f(txn):
|
||||
keyvalues = {
|
||||
|
@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
|||
values.append(except_token_id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||
values
|
||||
)
|
||||
tokens_and_devices = [(r[0], r[1]) for r in txn]
|
||||
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
|
||||
|
||||
for token, _ in tokens_and_devices:
|
||||
for token, _, _ in tokens_and_devices:
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_access_token, (token,)
|
||||
)
|
||||
|
|
|
@ -13,7 +13,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||
-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
|
||||
-- removed and replaced with 46/local_media_repository_url_idx.sql.
|
||||
--
|
||||
-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||
|
||||
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
||||
-- indices on expressions until 3.9.
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- drop the unique constraint on deleted_pushers so that we can just insert
|
||||
-- into it rather than upserting.
|
||||
|
||||
CREATE TABLE deleted_pushers2 (
|
||||
stream_id BIGINT NOT NULL,
|
||||
app_id TEXT NOT NULL,
|
||||
pushkey TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
|
||||
SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
|
||||
|
||||
DROP TABLE deleted_pushers;
|
||||
ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
|
||||
|
||||
-- create the index after doing the inserts because that's more efficient.
|
||||
-- it also means we can give it the same name as the old one without renaming.
|
||||
CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- register a background update which will recreate the
|
||||
-- local_media_repository_url_idx index.
|
||||
--
|
||||
-- We do this as a bg update not because it is a particularly onerous
|
||||
-- operation, but because we'd like it to be a partial index if possible, and
|
||||
-- the background_index_update code will understand whether we are on
|
||||
-- postgres or sqlite and behave accordingly.
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('local_media_repository_url_idx', '{}');
|
35
synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
Normal file
35
synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
Normal file
|
@ -0,0 +1,35 @@
|
|||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- change the user_directory table to also cover global local user profiles
|
||||
-- rather than just profiles within specific rooms.
|
||||
|
||||
CREATE TABLE user_directory2 (
|
||||
user_id TEXT NOT NULL,
|
||||
room_id TEXT,
|
||||
display_name TEXT,
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
|
||||
SELECT user_id, room_id, display_name, avatar_url from user_directory;
|
||||
|
||||
DROP TABLE user_directory;
|
||||
ALTER TABLE user_directory2 RENAME TO user_directory;
|
||||
|
||||
-- create indexes after doing the inserts because that's more efficient.
|
||||
-- it also means we can give it the same name as the old one without renaming.
|
||||
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
|
||||
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
|
|
@ -13,16 +13,18 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches import intern_string
|
||||
from synapse.util.stringutils import to_ascii
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from collections import namedtuple
|
||||
|
||||
import logging
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.stringutils import to_ascii
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -40,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
|
|||
return len(self.delta_ids) if self.delta_ids else 0
|
||||
|
||||
|
||||
class StateStore(SQLBaseStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
class StateGroupReadStore(SQLBaseStore):
|
||||
"""The read-only parts of StateGroupStore
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
a state group (identified by an arbitrary string), which references a
|
||||
collection of state events. The current state of an event is then the
|
||||
collection of state events referenced by the event's state group.
|
||||
|
||||
Hence, every change in the current state causes a new state group to be
|
||||
generated. However, if no change happens (e.g., if we get a message event
|
||||
with only one parent it inherits the state group from its parent.)
|
||||
|
||||
There are three tables:
|
||||
* `state_groups`: Stores group name, first event with in the group and
|
||||
room id.
|
||||
* `event_to_state_groups`: Maps events to state groups.
|
||||
* `state_groups_state`: Maps state group to state events.
|
||||
None of these functions write to the state tables, so are suitable for
|
||||
including in the SlavedStores.
|
||||
"""
|
||||
|
||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||
|
@ -64,21 +54,10 @@ class StateStore(SQLBaseStore):
|
|||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||
self._background_deduplicate_state,
|
||||
)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_INDEX_UPDATE_NAME,
|
||||
self._background_index_state,
|
||||
)
|
||||
self.register_background_index_update(
|
||||
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
||||
index_name="current_state_events_member_index",
|
||||
table="current_state_events",
|
||||
columns=["state_key"],
|
||||
where_clause="type='m.room.member'",
|
||||
super(StateGroupReadStore, self).__init__(db_conn, hs)
|
||||
|
||||
self._state_group_cache = DictionaryCache(
|
||||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
|
@ -190,178 +169,6 @@ class StateStore(SQLBaseStore):
|
|||
for group, event_id_map in group_to_ids.iteritems()
|
||||
})
|
||||
|
||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||
txn.execute(
|
||||
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||
(state_group,)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
return row and row[0]
|
||||
|
||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
if context.current_state_ids is None:
|
||||
# AFAIK, this can never happen
|
||||
logger.error(
|
||||
"Non-outlier event %s had current_state_ids==None",
|
||||
event.event_id)
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
state_groups[event.event_id] = context.prev_group
|
||||
continue
|
||||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
||||
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||
continue
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
values={
|
||||
"id": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
# We persist as a delta if we can, while also ensuring the chain
|
||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||
if context.prev_group:
|
||||
is_in_db = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
keyvalues={"id": context.prev_group},
|
||||
retcol="id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not is_in_db:
|
||||
raise Exception(
|
||||
"Trying to persist state with unpersisted prev_group: %r"
|
||||
% (context.prev_group,)
|
||||
)
|
||||
|
||||
potential_hops = self._count_state_group_hops_txn(
|
||||
txn, context.prev_group
|
||||
)
|
||||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
values={
|
||||
"state_group": context.state_group,
|
||||
"prev_state_group": context.prev_group,
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.delta_ids.iteritems()
|
||||
],
|
||||
)
|
||||
else:
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.current_state_ids.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
# Prefill the state group cache with this group.
|
||||
# It's fine to use the sequence like this as the state group map
|
||||
# is immutable. (If the map wasn't immutable then this prefill could
|
||||
# race with another update)
|
||||
txn.call_after(
|
||||
self._state_group_cache.update,
|
||||
self._state_group_cache.sequence,
|
||||
key=context.state_group,
|
||||
value=dict(context.current_state_ids),
|
||||
full=True,
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values=[
|
||||
{
|
||||
"state_group": state_group_id,
|
||||
"event_id": event_id,
|
||||
}
|
||||
for event_id, state_group_id in state_groups.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, state_group_id in state_groups.iteritems():
|
||||
txn.call_after(
|
||||
self._get_state_group_for_event.prefill,
|
||||
(event_id,), state_group_id
|
||||
)
|
||||
|
||||
def _count_state_group_hops_txn(self, txn, state_group):
|
||||
"""Given a state group, count how many hops there are in the tree.
|
||||
|
||||
This is used to ensure the delta chains don't get too long.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = ("""
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT count(*) FROM state;
|
||||
""")
|
||||
|
||||
txn.execute(sql, (state_group,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0]
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
next_group = state_group
|
||||
count = 0
|
||||
|
||||
while next_group:
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if next_group:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_groups_from_groups(self, groups, types):
|
||||
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
|
||||
|
@ -742,6 +549,220 @@ class StateStore(SQLBaseStore):
|
|||
|
||||
defer.returnValue(results)
|
||||
|
||||
|
||||
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
a state group (identified by an arbitrary string), which references a
|
||||
collection of state events. The current state of an event is then the
|
||||
collection of state events referenced by the event's state group.
|
||||
|
||||
Hence, every change in the current state causes a new state group to be
|
||||
generated. However, if no change happens (e.g., if we get a message event
|
||||
with only one parent it inherits the state group from its parent.)
|
||||
|
||||
There are three tables:
|
||||
* `state_groups`: Stores group name, first event with in the group and
|
||||
room id.
|
||||
* `event_to_state_groups`: Maps events to state groups.
|
||||
* `state_groups_state`: Maps state group to state events.
|
||||
"""
|
||||
|
||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateStore, self).__init__(db_conn, hs)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||
self._background_deduplicate_state,
|
||||
)
|
||||
self.register_background_update_handler(
|
||||
self.STATE_GROUP_INDEX_UPDATE_NAME,
|
||||
self._background_index_state,
|
||||
)
|
||||
self.register_background_index_update(
|
||||
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
||||
index_name="current_state_events_member_index",
|
||||
table="current_state_events",
|
||||
columns=["state_key"],
|
||||
where_clause="type='m.room.member'",
|
||||
)
|
||||
|
||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||
txn.execute(
|
||||
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||
(state_group,)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
return row and row[0]
|
||||
|
||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
if context.current_state_ids is None:
|
||||
# AFAIK, this can never happen
|
||||
logger.error(
|
||||
"Non-outlier event %s had current_state_ids==None",
|
||||
event.event_id)
|
||||
continue
|
||||
|
||||
# if the event was rejected, just give it the same state as its
|
||||
# predecessor.
|
||||
if context.rejected:
|
||||
state_groups[event.event_id] = context.prev_group
|
||||
continue
|
||||
|
||||
state_groups[event.event_id] = context.state_group
|
||||
|
||||
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||
continue
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
values={
|
||||
"id": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
# We persist as a delta if we can, while also ensuring the chain
|
||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||
if context.prev_group:
|
||||
is_in_db = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_groups",
|
||||
keyvalues={"id": context.prev_group},
|
||||
retcol="id",
|
||||
allow_none=True,
|
||||
)
|
||||
if not is_in_db:
|
||||
raise Exception(
|
||||
"Trying to persist state with unpersisted prev_group: %r"
|
||||
% (context.prev_group,)
|
||||
)
|
||||
|
||||
potential_hops = self._count_state_group_hops_txn(
|
||||
txn, context.prev_group
|
||||
)
|
||||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
values={
|
||||
"state_group": context.state_group,
|
||||
"prev_state_group": context.prev_group,
|
||||
},
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.delta_ids.iteritems()
|
||||
],
|
||||
)
|
||||
else:
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
values=[
|
||||
{
|
||||
"state_group": context.state_group,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for key, state_id in context.current_state_ids.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
# Prefill the state group cache with this group.
|
||||
# It's fine to use the sequence like this as the state group map
|
||||
# is immutable. (If the map wasn't immutable then this prefill could
|
||||
# race with another update)
|
||||
txn.call_after(
|
||||
self._state_group_cache.update,
|
||||
self._state_group_cache.sequence,
|
||||
key=context.state_group,
|
||||
value=dict(context.current_state_ids),
|
||||
full=True,
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_to_state_groups",
|
||||
values=[
|
||||
{
|
||||
"state_group": state_group_id,
|
||||
"event_id": event_id,
|
||||
}
|
||||
for event_id, state_group_id in state_groups.iteritems()
|
||||
],
|
||||
)
|
||||
|
||||
for event_id, state_group_id in state_groups.iteritems():
|
||||
txn.call_after(
|
||||
self._get_state_group_for_event.prefill,
|
||||
(event_id,), state_group_id
|
||||
)
|
||||
|
||||
def _count_state_group_hops_txn(self, txn, state_group):
|
||||
"""Given a state group, count how many hops there are in the tree.
|
||||
|
||||
This is used to ensure the delta chains don't get too long.
|
||||
"""
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = ("""
|
||||
WITH RECURSIVE state(state_group) AS (
|
||||
VALUES(?::bigint)
|
||||
UNION ALL
|
||||
SELECT prev_state_group FROM state_group_edges e, state s
|
||||
WHERE s.state_group = e.state_group
|
||||
)
|
||||
SELECT count(*) FROM state;
|
||||
""")
|
||||
|
||||
txn.execute(sql, (state_group,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0]
|
||||
else:
|
||||
return 0
|
||||
else:
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
next_group = state_group
|
||||
count = 0
|
||||
|
||||
while next_group:
|
||||
next_group = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="state_group_edges",
|
||||
keyvalues={"state_group": next_group},
|
||||
retcol="prev_state_group",
|
||||
allow_none=True,
|
||||
)
|
||||
if next_group:
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
def get_next_state_group(self):
|
||||
return self._state_groups_id_gen.get_next()
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ from ._base import SQLBaseStore
|
|||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import logging
|
||||
|
@ -234,7 +234,7 @@ class StreamStore(SQLBaseStore):
|
|||
results = {}
|
||||
room_ids = list(room_ids)
|
||||
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
res = yield make_deferred_yieldable(defer.gatherResults([
|
||||
preserve_fn(self.get_room_events_stream_for_room)(
|
||||
room_id, from_key, to_key, limit, order=order,
|
||||
)
|
||||
|
|
|
@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We weight the loclpart most highly, then display name and finally
|
||||
# We weight the localpart most highly, then display name and finally
|
||||
# server name
|
||||
if new_entry:
|
||||
sql = """
|
||||
|
@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
rows = yield self._execute("get_all_rooms", None, sql)
|
||||
defer.returnValue([room_id for room_id, in rows])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_local_users(self):
|
||||
"""Get all local users
|
||||
"""
|
||||
sql = """
|
||||
SELECT name FROM users
|
||||
"""
|
||||
rows = yield self._execute("get_all_local_users", None, sql)
|
||||
defer.returnValue([name for name, in rows])
|
||||
|
||||
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
|
||||
"""Insert entries into the users_who_share_rooms table. The first
|
||||
user should be a local user.
|
||||
|
@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
]
|
||||
}
|
||||
"""
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
join_clause = ""
|
||||
where_clause = "?<>''" # naughty hack to keep the same number of binds
|
||||
else:
|
||||
join_clause = """
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
LEFT JOIN (
|
||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||
WHERE user_id = ? AND share_private
|
||||
) AS s USING (user_id)
|
||||
"""
|
||||
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
|
||||
|
||||
|
@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
SELECT d.user_id, display_name, avatar_url
|
||||
FROM user_directory_search
|
||||
INNER JOIN user_directory AS d USING (user_id)
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
LEFT JOIN (
|
||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||
WHERE user_id = ? AND share_private
|
||||
) AS s USING (user_id)
|
||||
%s
|
||||
WHERE
|
||||
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
||||
%s
|
||||
AND vector @@ to_tsquery('english', ?)
|
||||
ORDER BY
|
||||
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
|
||||
|
@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
display_name IS NULL,
|
||||
avatar_url IS NULL
|
||||
LIMIT ?
|
||||
"""
|
||||
""" % (join_clause, where_clause)
|
||||
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
search_query = _parse_query_sqlite(search_term)
|
||||
|
@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore):
|
|||
SELECT d.user_id, display_name, avatar_url
|
||||
FROM user_directory_search
|
||||
INNER JOIN user_directory AS d USING (user_id)
|
||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||
LEFT JOIN (
|
||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||
WHERE user_id = ? AND share_private
|
||||
) AS s USING (user_id)
|
||||
%s
|
||||
WHERE
|
||||
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
||||
%s
|
||||
AND value MATCH ?
|
||||
ORDER BY
|
||||
rank(matchinfo(user_directory_search)) DESC,
|
||||
display_name IS NULL,
|
||||
avatar_url IS NULL
|
||||
LIMIT ?
|
||||
"""
|
||||
""" % (join_clause, where_clause)
|
||||
args = (user_id, search_query, limit + 1)
|
||||
else:
|
||||
# This should be unreachable.
|
||||
|
@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term):
|
|||
|
||||
# Pull out the individual words, discarding any non-word characters.
|
||||
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
||||
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
|
||||
return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
|
||||
|
||||
|
||||
def _parse_query_postgres(search_term):
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from twisted.internet import defer, reactor
|
||||
|
||||
from .logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||
PreserveLoggingContext, make_deferred_yieldable, preserve_fn
|
||||
)
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
|
||||
|
@ -351,7 +351,7 @@ class ReadWriteLock(object):
|
|||
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
yield curr_writer
|
||||
yield make_deferred_yieldable(curr_writer)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
@ -380,7 +380,7 @@ class ReadWriteLock(object):
|
|||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
|
||||
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
|
|
@ -13,32 +13,24 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import (
|
||||
PreserveLoggingContext, preserve_context_over_fn
|
||||
)
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def user_left_room(distributor, user, room_id):
|
||||
return preserve_context_over_fn(
|
||||
distributor.fire,
|
||||
"user_left_room", user=user, room_id=room_id
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||
|
||||
|
||||
def user_joined_room(distributor, user, room_id):
|
||||
return preserve_context_over_fn(
|
||||
distributor.fire,
|
||||
"user_joined_room", user=user, room_id=room_id
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
distributor.fire("user_joined_room", user=user, room_id=room_id)
|
||||
|
||||
|
||||
class Distributor(object):
|
||||
|
|
|
@ -261,67 +261,6 @@ class PreserveLoggingContext(object):
|
|||
)
|
||||
|
||||
|
||||
class _PreservingContextDeferred(defer.Deferred):
|
||||
"""A deferred that ensures that all callbacks and errbacks are called with
|
||||
the given logging context.
|
||||
"""
|
||||
def __init__(self, context):
|
||||
self._log_context = context
|
||||
defer.Deferred.__init__(self)
|
||||
|
||||
def addCallbacks(self, callback, errback=None,
|
||||
callbackArgs=None, callbackKeywords=None,
|
||||
errbackArgs=None, errbackKeywords=None):
|
||||
callback = self._wrap_callback(callback)
|
||||
errback = self._wrap_callback(errback)
|
||||
return defer.Deferred.addCallbacks(
|
||||
self, callback,
|
||||
errback=errback,
|
||||
callbackArgs=callbackArgs,
|
||||
callbackKeywords=callbackKeywords,
|
||||
errbackArgs=errbackArgs,
|
||||
errbackKeywords=errbackKeywords,
|
||||
)
|
||||
|
||||
def _wrap_callback(self, f):
|
||||
def g(res, *args, **kwargs):
|
||||
with PreserveLoggingContext(self._log_context):
|
||||
res = f(res, *args, **kwargs)
|
||||
return res
|
||||
return g
|
||||
|
||||
|
||||
def preserve_context_over_fn(fn, *args, **kwargs):
|
||||
"""Takes a function and invokes it with the given arguments, but removes
|
||||
and restores the current logging context while doing so.
|
||||
|
||||
If the result is a deferred, call preserve_context_over_deferred before
|
||||
returning it.
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
res = fn(*args, **kwargs)
|
||||
|
||||
if isinstance(res, defer.Deferred):
|
||||
return preserve_context_over_deferred(res)
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def preserve_context_over_deferred(deferred, context=None):
|
||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
||||
be invoked with the current context.
|
||||
|
||||
Deprecated: this almost certainly doesn't do want you want, ie make
|
||||
the deferred follow the synapse logcontext rules: try
|
||||
``make_deferred_yieldable`` instead.
|
||||
"""
|
||||
if context is None:
|
||||
context = LoggingContext.current_context()
|
||||
d = _PreservingContextDeferred(context)
|
||||
deferred.chainDeferred(d)
|
||||
return d
|
||||
|
||||
|
||||
def preserve_fn(f):
|
||||
"""Wraps a function, to ensure that the current context is restored after
|
||||
return from the function, and that the sentinel context is set once the
|
||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -58,7 +58,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
|
|||
always_include_ids (set(event_id)): set of event ids to specifically
|
||||
include (unless sender is ignored)
|
||||
"""
|
||||
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
forgotten = yield make_deferred_yieldable(defer.gatherResults([
|
||||
defer.maybeDeferred(
|
||||
preserve_fn(store.who_forgot_in_room),
|
||||
room_id,
|
||||
|
|
|
@ -36,6 +36,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
|||
id="unique_identifier",
|
||||
url="some_url",
|
||||
token="some_token",
|
||||
hostname="matrix.org", # only used by get_groups_for_user
|
||||
namespaces={
|
||||
ApplicationService.NS_USERS: [],
|
||||
ApplicationService.NS_ROOMS: [],
|
||||
|
|
|
@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
||||
mock_notifier = Mock(spec=["on_new_event"])
|
||||
mock_notifier = Mock()
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
|
@ -76,6 +76,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_devices_by_remote",
|
||||
# Bits that user_directory needs
|
||||
"get_user_directory_stream_pos",
|
||||
"get_current_state_deltas",
|
||||
]),
|
||||
state_handler=self.state_handler,
|
||||
handlers=None,
|
||||
|
@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
return set(str(u) for u in self.room_members)
|
||||
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
||||
|
||||
self.datastore.get_user_directory_stream_pos.return_value = (
|
||||
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
||||
defer.succeed(1)
|
||||
)
|
||||
|
||||
self.datastore.get_current_state_deltas.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
self.auth.check_joined_room = check_joined_room
|
||||
|
||||
self.datastore.get_to_device_stream_token = lambda: 0
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
from twisted.python import failure
|
||||
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
|
||||
from twisted.internet import defer
|
||||
from mock import Mock
|
||||
from tests import unittest
|
||||
|
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
side_effect=lambda x: self.appservice)
|
||||
)
|
||||
|
||||
self.auth_result = (False, None, None, None)
|
||||
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||
self.auth_handler = Mock(
|
||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||
get_session_data=Mock(return_value=None)
|
||||
|
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
self.request.args = {
|
||||
"access_token": "i_am_an_app_service"
|
||||
}
|
||||
|
||||
self.request_data = json.dumps({
|
||||
"username": "kermit"
|
||||
})
|
||||
|
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
"device_id": device_id,
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (True, None, {
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
|
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
|||
"password": "monkey"
|
||||
})
|
||||
self.registration_handler.check_username = Mock(return_value=True)
|
||||
self.auth_result = (True, None, {
|
||||
self.auth_result = (None, {
|
||||
"username": "kermit",
|
||||
"password": "monkey"
|
||||
}, None)
|
||||
|
|
Loading…
Reference in a new issue