Skip to content

Commit a7a70fd

Browse files
authored
Echo notifications (#1385)
* Echo notifications * Blackify * Clean up a bit * Add test * Blackify * Update changelog * Fix tests
1 parent 9f114c4 commit a7a70fd

File tree

7 files changed

+64
-5
lines changed

7 files changed

+64
-5
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ Contributors:
132132
* Sharon Yogev (sharonyogev)
133133
* Hollis Wu (holi0317)
134134
* Antonio Aguilar (crazybolillo)
135+
* Andrew M. MacFie (amacfie)
135136

136137
Creator:
137138
--------

changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Upcoming
44
Features:
55
---------
66
* Support `PGAPPNAME` as an environment variable and `--application-name` as a command line argument.
7+
* Show Postgres notifications
78

89
Bug fixes:
910
----------

pgcli/main.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373

7474
from getpass import getuser
7575

76-
from psycopg import OperationalError, InterfaceError
76+
from psycopg import OperationalError, InterfaceError, Notify
7777
from psycopg.conninfo import make_conninfo, conninfo_to_dict
7878

7979
from collections import namedtuple
@@ -128,6 +128,15 @@ class PgCliQuitError(Exception):
128128
pass
129129

130130

131+
def notify_callback(notify: Notify):
132+
click.secho(
133+
'Notification received on channel "{}" (PID {}):\n{}'.format(
134+
notify.channel, notify.pid, notify.payload
135+
),
136+
fg="green",
137+
)
138+
139+
131140
class PGCli:
132141
default_prompt = "\\u@\\h:\\d> "
133142
max_len_prompt = 30
@@ -660,7 +669,16 @@ def should_ask_for_password(exc):
660669
# prompt for a password (no -w flag), prompt for a passwd and try again.
661670
try:
662671
try:
663-
pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
672+
pgexecute = PGExecute(
673+
database,
674+
user,
675+
passwd,
676+
host,
677+
port,
678+
dsn,
679+
notify_callback,
680+
**kwargs,
681+
)
664682
except (OperationalError, InterfaceError) as e:
665683
if should_ask_for_password(e):
666684
passwd = click.prompt(
@@ -670,7 +688,14 @@ def should_ask_for_password(exc):
670688
type=str,
671689
)
672690
pgexecute = PGExecute(
673-
database, user, passwd, host, port, dsn, **kwargs
691+
database,
692+
user,
693+
passwd,
694+
host,
695+
port,
696+
dsn,
697+
notify_callback,
698+
**kwargs,
674699
)
675700
else:
676701
raise e

pgcli/pgexecute.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def __init__(
167167
host=None,
168168
port=None,
169169
dsn=None,
170+
notify_callback=None,
170171
**kwargs,
171172
):
172173
self._conn_params = {}
@@ -179,6 +180,7 @@ def __init__(
179180
self.port = None
180181
self.server_version = None
181182
self.extra_args = None
183+
self.notify_callback = notify_callback
182184
self.connect(database, user, password, host, port, dsn, **kwargs)
183185
self.reset_expanded = None
184186

@@ -237,6 +239,9 @@ def connect(
237239
self.conn = conn
238240
self.conn.autocommit = True
239241

242+
if self.notify_callback is not None:
243+
self.conn.add_notify_handler(self.notify_callback)
244+
240245
# When we connect using a DSN, we don't really know what db,
241246
# user, etc. we connected to. Let's read it.
242247
# Note: moved this after setting autocommit because of #664.

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
db_connection,
1010
drop_tables,
1111
)
12+
import pgcli.main
1213
import pgcli.pgexecute
1314

1415

@@ -37,6 +38,7 @@ def executor(connection):
3738
password=POSTGRES_PASSWORD,
3839
port=POSTGRES_PORT,
3940
dsn=None,
41+
notify_callback=pgcli.main.notify_callback,
4042
)
4143

4244

tests/test_main.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import platform
3+
import re
34
from unittest import mock
45

56
import pytest
@@ -13,6 +14,7 @@
1314
obfuscate_process_password,
1415
duration_in_words,
1516
format_output,
17+
notify_callback,
1618
PGCli,
1719
OutputSettings,
1820
COLOR_CODE_REGEX,
@@ -432,6 +434,7 @@ def test_pg_service_file(tmpdir):
432434
"b_host",
433435
"5435",
434436
"",
437+
notify_callback,
435438
application_name="pgcli",
436439
)
437440
del os.environ["PGPASSWORD"]
@@ -487,7 +490,7 @@ def test_application_name_db_uri(tmpdir):
487490
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
488491
cli.connect_uri("postgres://[email protected]/?application_name=cow")
489492
mock_pgexecute.assert_called_with(
490-
"bar", "bar", "", "baz.com", "", "", application_name="cow"
493+
"bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow"
491494
)
492495

493496

@@ -514,3 +517,23 @@ def test_application_name_db_uri(tmpdir):
514517
)
515518
def test_duration_in_words(duration_in_seconds, words):
516519
assert duration_in_words(duration_in_seconds) == words
520+
521+
522+
@dbtest
523+
def test_notifications(executor):
524+
run(executor, "listen chan1")
525+
526+
with mock.patch("pgcli.main.click.secho") as mock_secho:
527+
run(executor, "notify chan1, 'testing1'")
528+
mock_secho.assert_called()
529+
arg = mock_secho.call_args_list[0].args[0]
530+
assert re.match(
531+
r'Notification received on channel "chan1" \(PID \d+\):\ntesting1',
532+
arg,
533+
)
534+
535+
run(executor, "unlisten chan1")
536+
537+
with mock.patch("pgcli.main.click.secho") as mock_secho:
538+
run(executor, "notify chan1, 'testing2'")
539+
mock_secho.assert_not_called()

tests/test_ssh_tunnel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from click.testing import CliRunner
77
from sshtunnel import SSHTunnelForwarder
88

9-
from pgcli.main import cli, PGCli
9+
from pgcli.main import cli, notify_callback, PGCli
1010
from pgcli.pgexecute import PGExecute
1111

1212

@@ -61,6 +61,7 @@ def test_ssh_tunnel(
6161
"127.0.0.1",
6262
pgcli.ssh_tunnel.local_bind_ports[0],
6363
"",
64+
notify_callback,
6465
)
6566
mock_ssh_tunnel_forwarder.reset_mock()
6667
mock_pgexecute.reset_mock()
@@ -96,6 +97,7 @@ def test_ssh_tunnel(
9697
"127.0.0.1",
9798
pgcli.ssh_tunnel.local_bind_ports[0],
9899
"",
100+
notify_callback,
99101
)
100102
mock_ssh_tunnel_forwarder.reset_mock()
101103
mock_pgexecute.reset_mock()

0 commit comments

Comments
 (0)