Skip to content

Commit

Permalink
WIP: Pump disappearing messages from db
Browse files Browse the repository at this point in the history
  • Loading branch information
hifi committed Dec 11, 2023
1 parent 562f646 commit 493955c
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 13 deletions.
50 changes: 40 additions & 10 deletions mautrix_telegram/db/disappearing_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import TYPE_CHECKING, ClassVar

import asyncpg
import time

from mautrix.bridge import AbstractDisappearingMessage
from mautrix.types import EventID, RoomID
Expand All @@ -27,6 +28,7 @@


class DisappearingMessage(AbstractDisappearingMessage):
unqueued_ts: int | None = None
db: ClassVar[Database] = fake_db

async def insert(self) -> None:
Expand All @@ -50,6 +52,40 @@ async def delete(self) -> None:
def _from_row(cls, row: asyncpg.Record) -> DisappearingMessage:
return cls(**row)

"""
Get all scheduled messages that will expire in given seconds that haven't yet been unqueued.
This will also stamp them in the database for being unqueued so every time this method is called
there should be a unique set of events. If seconds is None then all events will be returned
regardless of being requested before.
The first call on startup should be with None and subsequent with the previous value.
"""
@classmethod
async def unqueue_expiring(cls, seconds: int | None = None) -> list[DisappearingMessage]:
unqueued_ts = int(time.time() * 1000)

rows = None
if seconds is None:
q = """
SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message
WHERE expiration_ts <= $1
"""
rows = await cls.db.fetch(q, unqueued_ts)
else:
q = """
SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message
WHERE expiration_ts <= $1 AND (unqueued_ts IS NULL OR unqueued_ts < $2)
"""
rows = await cls.db.fetch(q, unqueued_ts + (seconds * 1000), unqueued_ts)

msgs = [cls._from_row(r) for r in rows]
for msg in msgs:
msg.unqueued_ts = unqueued_ts
await msg.update()

return msgs

@classmethod
async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage | None:
q = """
Expand All @@ -63,16 +99,10 @@ async def get(cls, room_id: RoomID, event_id: EventID) -> DisappearingMessage |

@classmethod
async def get_all_scheduled(cls) -> list[DisappearingMessage]:
q = """
SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message
WHERE expiration_ts IS NOT NULL
"""
return [cls._from_row(r) for r in await cls.db.fetch(q)]
# Stubbed because we pump with unqueue_expiring
return []

@classmethod
async def get_unscheduled_for_room(cls, room_id: RoomID) -> list[DisappearingMessage]:
q = """
SELECT room_id, event_id, expiration_seconds, expiration_ts FROM disappearing_message
WHERE room_id = $1 AND expiration_ts IS NULL
"""
return [cls._from_row(r) for r in await cls.db.fetch(q, room_id)]
# Stubbed because we pump with unqueue_expiring
return []
1 change: 1 addition & 0 deletions mautrix_telegram/db/upgrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
v16_backfill_type,
v17_message_find_recent,
v18_puppet_contact_info_set,
v19_disappearing_message_unqueue,
)
4 changes: 3 additions & 1 deletion mautrix_telegram/db/upgrade/v00_latest_revision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme

latest_version = 18
latest_version = 19


async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
Expand Down Expand Up @@ -92,10 +92,12 @@ async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
event_id TEXT,
expiration_seconds BIGINT,
expiration_ts BIGINT,
unqueued_ts BIGINT,
PRIMARY KEY (room_id, event_id)
)"""
)
await conn.execute("CREATE INDEX disappearing_message_expiration_ts ON disappearing_message(expiration_ts)")
await conn.execute(
"""CREATE TABLE puppet (
id BIGINT PRIMARY KEY,
Expand Down
26 changes: 26 additions & 0 deletions mautrix_telegram/db/upgrade/v19_disappearing_message_unqueue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection

from . import upgrade_table


@upgrade_table.register(description="Add index on disappearing_message expiration_ts, unqueued_ts column")
async def upgrade_v19(conn: Connection) -> None:
await conn.execute(
"ALTER TABLE disappearing_message ADD COLUMN unqueued_ts BIGINT"
)
await conn.execute("CREATE INDEX disappearing_message_expiration_ts ON disappearing_message(expiration_ts)")
40 changes: 39 additions & 1 deletion mautrix_telegram/portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ class Portal(DBPortal, BasePortal):

_msg_conv: putil.TelegramMessageConverter

_disappearing_event: asyncio.Event

def __init__(
self,
tgid: TelegramID,
Expand Down Expand Up @@ -468,6 +470,42 @@ def set_dm_room_metadata(self) -> bool:
or (self.encrypted and self.private_chat_portal_meta != "never")
)

@classmethod
async def _disappearing_message_loop(cls, seconds: int | None = None) -> None:
try:
seconds = None
while True:
print("fetching disappearing")
cls._disappearing_event.clear()
msgs = await cls.disappearing_msg_class.unqueue_expiring(seconds)
print(f"got {len(msgs)} rows")
for msg in msgs:
print("handling disappear thing")
portal = await cls.bridge.get_portal(msg.room_id)
if portal and portal.mxid:
background_task.create(portal._disappear_event(msg))
else:
await msg.delete()

try:
await asyncio.wait_for(cls._disappearing_event.wait(), 10)
except TimeoutError:
pass

seconds = 10
except RuntimeError:
return

@classmethod
async def restart_scheduled_disappearing(cls) -> None:
cls._disappearing_event = asyncio.Event()
background_task.create(cls._disappearing_message_loop())

@classmethod
async def notify_disappearing_message_loop(cls) -> None:
print("notifying disappear loop")
cls._disappearing_event.set()

@classmethod
def init_cls(cls, bridge: "TelegramBridge") -> None:
BasePortal.bridge = bridge
Expand Down Expand Up @@ -3531,7 +3569,7 @@ async def _mark_disappearing(
)
await dm.insert()
if expires_at:
background_task.create(self._disappear_event(dm))
Portal.notify_disappearing_message_loop()

async def _create_room_on_action(
self, source: au.AbstractUser, action: TypeMessageAction
Expand Down
7 changes: 6 additions & 1 deletion mautrix_telegram/version.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from .get_version import git_revision, git_tag, linkified_version, version
# Generated in setup.py

git_tag = None
git_revision = "e3a067c2"
version = "0.11.3+dev.e3a067c2"
linkified_version = "0.11.3+dev.[e3a067c2](https://github.com/mautrix/telegram/commit/e3a067c27aa3d9dd5e82db307218cc66c8356ddd)"

0 comments on commit 493955c

Please sign in to comment.