[proxy]
github.com
—
← back
|
site home
|
direct (HTTPS) ↗
|
proxy home
|
◑ dark
◐ light
pgvector
/
pgvector-python
Public
Notifications
You must be signed in to change notification settings
Fork
89
Star
1.4k
Files
Expand file tree
master
/
example.py
Copy path
Blame
More file actions
Blame
More file actions
Latest commit
History
History
History
47 lines (36 loc) · 1.44 KB
master
/
example.py
Top
File metadata and controls
Code
Blame
47 lines (36 loc) · 1.44 KB
Raw
Copy raw file
Download raw file
Open symbols panel
Edit and raw actions
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from
datasets
import
load_dataset
from
imagehash
import
phash
import
matplotlib
.
pyplot
as
plt
from
pgvector
.
psycopg
import
register_vector
,
Bit
import
psycopg
def hash_image(img):
    """Return the perceptual hash of *img* as a string of '0'/'1' characters.

    The bits are taken from ``phash(img).hash`` flattened into one sequence,
    in order; each truthy element becomes '1' and each falsy one '0'.
    NOTE(review): presumably 64 bits with phash's default settings, matching
    the ``bit(64)`` column this script creates — confirm if phash args change.
    """
    bits = phash(img).hash.flatten()
    return ''.join('1' if bit else '0' for bit in bits)
# autocommit=True: every statement below is committed as it executes, so no
# explicit commit() is needed.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
# Adapters are registered only after the extension exists (presumably
# register_vector needs the extension's types to be present).
register_vector(conn)

# Start from an empty table on every run.
for statement in (
    'DROP TABLE IF EXISTS images',
    'CREATE TABLE images (id bigserial PRIMARY KEY, hash bit(64))',
):
    conn.execute(statement)
print('Loading dataset')
dataset = load_dataset('mnist')

print('Generating hashes')
# One {'hash': <bit string>} record per training image, in dataset order.
train_split = dataset['train']
images = [dict(hash=hash_image(example['image'])) for example in train_split]
print('Storing hashes')
# Bulk-load every hash with a single COPY. The cursor is opened as a context
# manager (alongside the COPY) so it is closed as soon as loading finishes
# instead of lingering until the connection goes away.
with conn.cursor() as cur, cur.copy('COPY images (hash) FROM STDIN') as copy:
    for image in images:
        # Wrapping in Bit selects the adapter installed by register_vector
        # for the bit(64) column.
        copy.write_row([Bit(image['hash'])])
print('Querying hashes')
results = []
for i in range(5):
    image = dataset['test'][i]['image']
    # Send the query hash as Bit, the same type used when storing, rather than
    # a plain str whose type the server would have to infer.
    # (Per the pgvector docs, <~> is the Hamming-distance operator for bit.)
    query_hash = Bit(hash_image(image))
    result = conn.execute(
        'SELECT id FROM images ORDER BY hash <~> %s LIMIT 5',
        (query_hash,),
    ).fetchall()
    # Assumes ids start at 1 and follow insertion order. That holds here
    # because the table is recreated and filled in dataset order, so id n
    # maps to training row n - 1.
    nearest_images = [dataset['train'][row[0] - 1]['image'] for row in result]
    # Row layout: the query image first, then its nearest neighbours.
    results.append([image] + nearest_images)
print('Showing results (first column is query image)')
# Size the grid by the longest row so rows with fewer neighbours still fit.
# squeeze=False keeps `axs` a 2-D array even for a single row or column, so
# the `axs[i, j]` indexing below always works.
n_rows = len(results)
n_cols = max(len(row_images) for row_images in results)
fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
for i, result in enumerate(results):
    for j, image in enumerate(result):
        axs[i, j].imshow(image)
# Hide axes on every cell, including any left empty by shorter rows.
for ax in axs.flat:
    ax.set_axis_off()
plt.show(block=True)