Skip to content

Commit ece7ff6

Browse files
(TS) Fix PGVector implementation, where vector distance was inverted. (#4944)
1 parent 30ce028 commit ece7ff6

2 files changed

Lines changed: 129 additions & 1 deletion

File tree

mem0-ts/src/oss/src/vector_stores/pgvector.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ export class PGVector implements VectorStore {
251251
return result.rows.map((row) => ({
252252
id: row.id,
253253
payload: row.payload,
254-
score: row.distance,
254+
score: Math.max(0, Math.min(1, 1 - Number(row.distance))),
255255
}));
256256
}
257257

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/// <reference types="jest" />
2+
3+
const searchRows = [
4+
{
5+
id: "a",
6+
payload: { data: "exactly x-axis" },
7+
distance: "0",
8+
},
9+
{
10+
id: "b",
11+
payload: { data: "close to x-axis" },
12+
distance: "0.006116251198662548",
13+
},
14+
{
15+
id: "c",
16+
payload: { data: "y-axis" },
17+
distance: "1",
18+
},
19+
{
20+
id: "d",
21+
payload: { data: "opposite x-axis" },
22+
distance: "2",
23+
},
24+
];
25+
26+
function mockPgQuery(sql: string) {
27+
if (sql.includes("SELECT 1 FROM pg_database")) {
28+
return { rows: [{ "?column?": 1 }] };
29+
}
30+
31+
if (sql.includes("FROM information_schema.tables")) {
32+
return { rows: [{ table_name: "memories" }] };
33+
}
34+
35+
if (sql.includes("vector <=> $1::vector AS distance")) {
36+
return { rows: searchRows };
37+
}
38+
39+
return { rows: [] };
40+
}
41+
42+
jest.mock("pg", () => {
43+
const clients: any[] = [];
44+
45+
const Client = jest.fn().mockImplementation((config: any) => {
46+
const client = {
47+
config,
48+
connect: jest.fn().mockResolvedValue(undefined),
49+
end: jest.fn().mockResolvedValue(undefined),
50+
query: jest
51+
.fn()
52+
.mockImplementation(async (sql: string) => mockPgQuery(sql)),
53+
};
54+
55+
clients.push(client);
56+
return client;
57+
});
58+
59+
return {
60+
__esModule: true,
61+
default: { Client },
62+
Client,
63+
__mock: { Client, clients },
64+
};
65+
});
66+
67+
import { PGVector } from "../src/vector_stores/pgvector";
68+
69+
describe("PGVector - search()", () => {
70+
beforeEach(() => {
71+
const pg = require("pg");
72+
pg.__mock.Client.mockClear();
73+
pg.__mock.clients.length = 0;
74+
});
75+
76+
test("converts cosine distance into a clamped similarity score", async () => {
77+
const store = new PGVector({
78+
collectionName: "memories",
79+
user: "postgres",
80+
password: "postgres",
81+
host: "localhost",
82+
port: 5432,
83+
embeddingModelDims: 3,
84+
dimension: 3,
85+
} as any);
86+
87+
await store.initialize();
88+
89+
const results = await store.search([1, 0, 0], 4);
90+
91+
expect(results).toEqual([
92+
{
93+
id: "a",
94+
payload: { data: "exactly x-axis" },
95+
score: 1,
96+
},
97+
{
98+
id: "b",
99+
payload: { data: "close to x-axis" },
100+
score: 0.9938837488013375,
101+
},
102+
{
103+
id: "c",
104+
payload: { data: "y-axis" },
105+
score: 0,
106+
},
107+
{
108+
id: "d",
109+
payload: { data: "opposite x-axis" },
110+
score: 0,
111+
},
112+
]);
113+
114+
const pg = require("pg");
115+
expect(pg.__mock.Client).toHaveBeenCalledTimes(2);
116+
117+
const activeClient = pg.__mock.clients[1];
118+
expect(activeClient.query).toHaveBeenCalledWith(
119+
expect.stringContaining("vector <=> $1::vector AS distance"),
120+
["[1,0,0]", 4],
121+
);
122+
123+
for (const result of results) {
124+
expect(result.score).toBeGreaterThanOrEqual(0);
125+
expect(result.score).toBeLessThanOrEqual(1);
126+
}
127+
});
128+
});

0 commit comments

Comments
 (0)