-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy pathedge_operations.py
More file actions
741 lines (635 loc) · 28.4 KB
/
edge_operations.py
File metadata and controls
741 lines (635 loc) · 28.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from datetime import datetime
from time import time
from pydantic import BaseModel
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
EpisodicEdge,
create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import Edge as ExtractedEdge
from graphiti_core.prompts.extract_edges import ExtractedEdges
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
logger = logging.getLogger(__name__)
def build_episodic_edges(
    entity_nodes: list[EntityNode],
    episode_uuid: str,
    created_at: datetime,
) -> list[EpisodicEdge]:
    """Create one EpisodicEdge from the episode to each extracted entity node.

    Each edge points from the episode (source) to an entity (target) and
    inherits the entity's group_id.
    """
    episodic_edges: list[EpisodicEdge] = []
    for entity in entity_nodes:
        episodic_edges.append(
            EpisodicEdge(
                source_node_uuid=episode_uuid,
                target_node_uuid=entity.uuid,
                created_at=created_at,
                group_id=entity.group_id,
            )
        )
    logger.debug(f'Built {len(episodic_edges)} episodic edges')
    return episodic_edges
def build_community_edges(
    entity_nodes: list[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> list[CommunityEdge]:
    """Create one CommunityEdge from the community node to each member entity.

    Edges point from the community (source) to each entity (target) and carry
    the community's group_id.
    """
    membership_edges: list[CommunityEdge] = []
    for member in entity_nodes:
        membership_edges.append(
            CommunityEdge(
                source_node_uuid=community_node.uuid,
                target_node_uuid=member.uuid,
                created_at=created_at,
                group_id=community_node.group_id,
            )
        )
    return membership_edges
async def extract_edges(
    clients: GraphitiClients,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    edge_type_map: dict[tuple[str, str], list[str]],
    group_id: str = '',
    edge_types: dict[str, type[BaseModel]] | None = None,
    custom_extraction_instructions: str | None = None,
) -> list[EntityEdge]:
    """Extract new entity-to-entity edges (facts) from an episode via the LLM.

    Parameters
    ----------
    clients : GraphitiClients
        Supplies the LLM client used for the extraction call.
    episode : EpisodicNode
        Episode whose content is mined for facts; its `valid_at` is passed to
        the prompt as the reference time and its uuid is attached to every edge.
    nodes : list[EntityNode]
        Entities already extracted for this episode; edges may only connect
        nodes from this list (LLM-returned names are validated against it).
    previous_episodes : list[EpisodicNode]
        Prior episode contents given to the LLM as context.
    edge_type_map : dict[tuple[str, str], list[str]]
        Maps (source_label, target_label) signatures to allowed edge type names.
    group_id : str
        Group id stamped on every created edge.
    edge_types : dict[str, type[BaseModel]] | None
        Optional custom edge type models; each model's docstring is sent to the
        LLM as the fact type description.
    custom_extraction_instructions : str | None
        Extra prompt instructions; an empty string is sent when None.

    Returns
    -------
    list[EntityEdge]
        Newly constructed edges (not yet deduplicated against the graph or
        persisted). Edges with empty facts or unknown endpoints are dropped.
    """
    start = time()
    # Response-size cap for the extraction LLM call.
    extract_edges_max_tokens = 16384
    llm_client = clients.llm_client
    # Build mapping from edge type name to list of valid signatures
    edge_type_signatures_map: dict[str, list[tuple[str, str]]] = {}
    for signature, edge_type_names in edge_type_map.items():
        for edge_type in edge_type_names:
            if edge_type not in edge_type_signatures_map:
                edge_type_signatures_map[edge_type] = []
            edge_type_signatures_map[edge_type].append(signature)
    # Per-type prompt context; types with no signature fall back to the
    # generic ('Entity', 'Entity') signature.
    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_signatures': edge_type_signatures_map.get(
                    type_name, [('Entity', 'Entity')]
                ),
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_types.items()
        ]
        if edge_types is not None
        else []
    )
    # Build name-to-node mapping for validation
    name_to_node: dict[str, EntityNode] = {node.name: node for node in nodes}
    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'nodes': [{'name': node.name, 'entity_types': node.labels} for node in nodes],
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_time': episode.valid_at,
        'edge_types': edge_types_context,
        'custom_extraction_instructions': custom_extraction_instructions or '',
    }
    llm_response = await llm_client.generate_response(
        prompt_library.extract_edges.edge(context),
        response_model=ExtractedEdges,
        max_tokens=extract_edges_max_tokens,
        group_id=group_id,
        prompt_name='extract_edges.edge',
    )
    all_edges_data = ExtractedEdges(**llm_response).edges
    # Validate entity names: drop any edge whose endpoints the LLM hallucinated.
    edges_data: list[ExtractedEdge] = []
    for edge_data in all_edges_data:
        source_name = edge_data.source_entity_name
        target_name = edge_data.target_entity_name
        # Validate LLM-returned names exist in the nodes list
        if source_name not in name_to_node:
            logger.warning(
                'Source entity not found in nodes for edge relation: %s',
                edge_data.relation_type,
            )
            continue
        if target_name not in name_to_node:
            logger.warning(
                'Target entity not found in nodes for edge relation: %s',
                edge_data.relation_type,
            )
            continue
        edges_data.append(edge_data)
    end = time()
    logger.debug(f'Extracted {len(edges_data)} new edges in {(end - start) * 1000:.0f} ms')
    if len(edges_data) == 0:
        return []
    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        # Validate Edge Date information
        valid_at = edge_data.valid_at
        invalid_at = edge_data.invalid_at
        valid_at_datetime = None
        invalid_at_datetime = None
        # Filter out empty edges
        if not edge_data.fact.strip():
            continue
        # Names already validated above
        source_node = name_to_node.get(edge_data.source_entity_name)
        target_node = name_to_node.get(edge_data.target_entity_name)
        if source_node is None or target_node is None:
            logger.warning('Could not find source or target node for extracted edge')
            continue
        source_node_uuid = source_node.uuid
        target_node_uuid = target_node.uuid
        # Parse ISO-8601 timestamps from the LLM; 'Z' suffix is normalized to
        # an explicit UTC offset. Unparseable dates are logged and left as None.
        if valid_at:
            try:
                valid_at_datetime = ensure_utc(
                    datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
                )
            except ValueError as e:
                logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
        if invalid_at:
            try:
                invalid_at_datetime = ensure_utc(
                    datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
                )
            except ValueError as e:
                logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
        edge = EntityEdge(
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            name=edge_data.relation_type,
            group_id=group_id,
            fact=edge_data.fact,
            episodes=[episode.uuid],
            created_at=utc_now(),
            valid_at=valid_at_datetime,
            invalid_at=invalid_at_datetime,
        )
        edges.append(edge)
        logger.debug(
            f'Created new edge {edge.uuid} from {edge.source_node_uuid} to {edge.target_node_uuid}'
        )
    logger.debug(f'Extracted edges: {[e.uuid for e in edges]}')
    return edges
async def resolve_extracted_edges(
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    episode: EpisodicNode,
    entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge], list[EntityEdge]]:
    """Resolve extracted edges against existing graph context.

    Returns
    -------
    tuple[list[EntityEdge], list[EntityEdge], list[EntityEdge]]
        A tuple of (resolved_edges, invalidated_edges, new_edges) where:
        - resolved_edges: All edges after resolution (may include existing edges if duplicates found)
        - invalidated_edges: Edges that were invalidated/contradicted by new information
        - new_edges: Only edges that are new to the graph (not duplicates of existing edges)
    """
    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
    seen: dict[tuple[str, str, str], EntityEdge] = {}
    deduplicated_edges: list[EntityEdge] = []
    for edge in extracted_edges:
        key = (
            edge.source_node_uuid,
            edge.target_node_uuid,
            _normalize_string_exact(edge.fact),
        )
        if key not in seen:
            seen[key] = edge
            deduplicated_edges.append(edge)
    extracted_edges = deduplicated_edges
    driver = clients.driver
    llm_client = clients.llm_client
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)
    # Fetch existing edges between each extracted edge's endpoints, in both
    # directions, concurrently.
    forward_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
            for edge in extracted_edges
        ]
    )
    inverse_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.target_node_uuid, edge.source_node_uuid)
            for edge in extracted_edges
        ]
    )
    # Merge forward and inverse results per extracted edge, dropping uuid
    # duplicates while preserving order.
    valid_edges_list: list[list[EntityEdge]] = []
    for forward_edges, inverse_edges in zip(forward_edges_list, inverse_edges_list, strict=True):
        seen_uuids: set[str] = set()
        combined: list[EntityEdge] = []
        for edge in [*forward_edges, *inverse_edges]:
            if edge.uuid not in seen_uuids:
                seen_uuids.add(edge.uuid)
                combined.append(edge)
        valid_edges_list.append(combined)
    # Duplicate candidates: hybrid search over the fact text, restricted to
    # edges that share the same endpoints.
    related_edges_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
        ]
    )
    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
    # Invalidation candidates: the same search without the endpoint restriction.
    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
            for extracted_edge in extracted_edges
        ]
    )
    # Remove duplicates: if an edge appears in both duplicate candidates and invalidation candidates,
    # keep it only in duplicate candidates
    edge_invalidation_candidates: list[list[EntityEdge]] = []
    for related_edges, invalidation_result in zip(
        related_edges_lists, edge_invalidation_candidate_results, strict=True
    ):
        related_uuids = {edge.uuid for edge in related_edges}
        deduplicated = [
            edge for edge in invalidation_result.edges if edge.uuid not in related_uuids
        ]
        edge_invalidation_candidates.append(deduplicated)
    logger.debug(
        f'Related edges: {[e.uuid for edges_lst in related_edges_lists for e in edges_lst]}'
    )
    # Build entity hash table
    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}
    # Collect all node UUIDs referenced by edges that are not in the entities list
    referenced_node_uuids = set()
    for extracted_edge in extracted_edges:
        if extracted_edge.source_node_uuid not in uuid_entity_map:
            referenced_node_uuids.add(extracted_edge.source_node_uuid)
        if extracted_edge.target_node_uuid not in uuid_entity_map:
            referenced_node_uuids.add(extracted_edge.target_node_uuid)
    # Fetch missing nodes from the database
    if referenced_node_uuids:
        missing_nodes = await EntityNode.get_by_uuids(driver, list(referenced_node_uuids))
        for node in missing_nodes:
            uuid_entity_map[node.uuid] = node
    # Determine which edge types are relevant for each edge based on node signatures.
    # `edge_types_lst` stores the subset of custom edge definitions whose
    # node signature matches each extracted edge.
    edge_types_lst: list[dict[str, type[BaseModel]]] = []
    for extracted_edge in extracted_edges:
        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
        # Every node implicitly carries the base 'Entity' label.
        source_node_labels = (
            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
        )
        target_node_labels = (
            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
        )
        label_tuples = [
            (source_label, target_label)
            for source_label in source_node_labels
            for target_label in target_node_labels
        ]
        extracted_edge_types = {}
        for label_tuple in label_tuples:
            type_names = edge_type_map.get(label_tuple, [])
            for type_name in type_names:
                type_model = edge_types.get(type_name)
                if type_model is None:
                    continue
                extracted_edge_types[type_name] = type_model
        edge_types_lst.append(extracted_edge_types)
    # resolve edges with related edges in the graph and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
        await semaphore_gather(
            *[
                resolve_extracted_edge(
                    llm_client,
                    extracted_edge,
                    related_edges,
                    existing_edges,
                    episode,
                    extracted_edge_types,
                )
                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                    extracted_edges,
                    related_edges_lists,
                    edge_invalidation_candidates,
                    edge_types_lst,
                    strict=True,
                )
            ]
        )
    )
    resolved_edges: list[EntityEdge] = []
    invalidated_edges: list[EntityEdge] = []
    new_edges: list[EntityEdge] = []
    for extracted_edge, result in zip(extracted_edges, results, strict=True):
        resolved_edge = result[0]
        invalidated_edge_chunk = result[1]
        # result[2] is duplicate_edges list
        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)
        # Track edges that are new (not duplicates of existing edges)
        # An edge is new if the resolved edge UUID matches the extracted edge UUID
        if resolved_edge.uuid == extracted_edge.uuid:
            new_edges.append(resolved_edge)
    logger.debug(f'Resolved edges: {[e.uuid for e in resolved_edges]}')
    logger.debug(f'New edges (non-duplicates): {[e.uuid for e in new_edges]}')
    await semaphore_gather(
        create_entity_edge_embeddings(embedder, resolved_edges),
        create_entity_edge_embeddings(embedder, invalidated_edges),
    )
    return resolved_edges, invalidated_edges, new_edges
def resolve_edge_contradictions(
    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
    """Expire candidate edges whose validity interval is superseded by `resolved_edge`.

    A candidate is invalidated (its `invalid_at` set to the resolved edge's
    `valid_at`, and `expired_at` stamped if unset) when its validity interval
    overlaps the resolved edge's and it became valid strictly earlier.
    Candidates whose intervals do not overlap the resolved edge's are skipped.
    """
    if not invalidation_candidates:
        return []
    # Resolved edge bounds are loop-invariant; normalize to UTC once.
    new_valid_utc = ensure_utc(resolved_edge.valid_at)
    new_invalid_utc = ensure_utc(resolved_edge.invalid_at)
    invalidated_edges: list[EntityEdge] = []
    for candidate in invalidation_candidates:
        cand_valid_utc = ensure_utc(candidate.valid_at)
        cand_invalid_utc = ensure_utc(candidate.invalid_at)
        # Skip when the intervals cannot overlap:
        # candidate ended before the new edge began, or the new edge ended
        # before the candidate began.
        candidate_ends_first = (
            cand_invalid_utc is not None
            and new_valid_utc is not None
            and cand_invalid_utc <= new_valid_utc
        )
        new_edge_ends_first = (
            cand_valid_utc is not None
            and new_invalid_utc is not None
            and new_invalid_utc <= cand_valid_utc
        )
        if candidate_ends_first or new_edge_ends_first:
            continue
        # New edge invalidates the older, overlapping candidate.
        if (
            cand_valid_utc is not None
            and new_valid_utc is not None
            and cand_valid_utc < new_valid_utc
        ):
            candidate.invalid_at = resolved_edge.valid_at
            if candidate.expired_at is None:
                candidate.expired_at = utc_now()
            invalidated_edges.append(candidate)
    return invalidated_edges
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    episode: EpisodicNode,
    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
    """Resolve an extracted edge against existing graph context.

    Parameters
    ----------
    llm_client : LLMClient
        Client used to invoke the LLM for deduplication and attribute extraction.
    extracted_edge : EntityEdge
        Newly extracted edge whose canonical representation is being resolved.
    related_edges : list[EntityEdge]
        Candidate edges with identical endpoints used for duplicate detection.
    existing_edges : list[EntityEdge]
        Broader set of edges evaluated for contradiction / invalidation.
    episode : EpisodicNode
        Episode providing content context when extracting edge attributes.
    edge_type_candidates : dict[str, type[BaseModel]] | None
        Custom edge types permitted for the current source/target signature.

    Returns
    -------
    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
        The resolved edge, any duplicates, and edges to invalidate.
    """
    if len(related_edges) == 0 and len(existing_edges) == 0:
        # Still extract custom attributes even when no dedup/invalidation is needed
        edge_model = (
            edge_type_candidates.get(extracted_edge.name) if edge_type_candidates else None
        )
        # Only call the LLM when the custom model actually declares fields.
        if edge_model is not None and len(edge_model.model_fields) != 0:
            edge_attributes_context = {
                'fact': extracted_edge.fact,
                'reference_time': episode.valid_at if episode is not None else None,
                'existing_attributes': extracted_edge.attributes,
            }
            edge_attributes_response = await llm_client.generate_response(
                prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                response_model=edge_model,  # type: ignore
                model_size=ModelSize.small,
                prompt_name='extract_edges.extract_attributes',
            )
            extracted_edge.attributes = edge_attributes_response
        return extracted_edge, [], []
    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
    normalized_fact = _normalize_string_exact(extracted_edge.fact)
    for edge in related_edges:
        if (
            edge.source_node_uuid == extracted_edge.source_node_uuid
            and edge.target_node_uuid == extracted_edge.target_node_uuid
            and _normalize_string_exact(edge.fact) == normalized_fact
        ):
            resolved = edge
            if episode is not None and episode.uuid not in resolved.episodes:
                resolved.episodes.append(episode.uuid)
            return resolved, [], []
    start = time()
    # Prepare context for LLM with continuous indexing
    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
    # Invalidation candidates start where duplicate candidates end
    invalidation_idx_offset = len(related_edges)
    invalidation_edge_candidates_context = [
        {'idx': invalidation_idx_offset + i, 'fact': existing_edge.fact}
        for i, existing_edge in enumerate(existing_edges)
    ]
    context = {
        'existing_edges': related_edges_context,
        'new_edge': extracted_edge.fact,
        'edge_invalidation_candidates': invalidation_edge_candidates_context,
    }
    if related_edges or existing_edges:
        logger.debug(
            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
            len(related_edges),
            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
            len(existing_edges),
            f' (idx {invalidation_idx_offset}-{invalidation_idx_offset + len(existing_edges) - 1})'
            if existing_edges
            else '',
        )
    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.resolve_edge(context),
        response_model=EdgeDuplicate,
        model_size=ModelSize.small,
        prompt_name='dedupe_edges.resolve_edge',
    )
    response_object = EdgeDuplicate(**llm_response)
    duplicate_facts = response_object.duplicate_facts
    # Validate duplicate_facts are in valid range for EXISTING FACTS
    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
    if invalid_duplicates:
        logger.warning(
            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
            invalid_duplicates,
            len(related_edges) - 1,
        )
    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
    # The first in-range duplicate (if any) becomes the canonical resolved edge;
    # otherwise the extracted edge stands as-is.
    resolved_edge = extracted_edge
    for duplicate_fact_id in duplicate_fact_ids:
        resolved_edge = related_edges[duplicate_fact_id]
        break
    if duplicate_fact_ids and episode is not None:
        resolved_edge.episodes.append(episode.uuid)
    # Process contradicted facts (continuous indexing across both lists)
    contradicted_facts: list[int] = response_object.contradicted_facts
    invalidation_candidates: list[EntityEdge] = []
    # Only process contradictions if there are edges to check against
    if related_edges or existing_edges:
        max_valid_idx = len(related_edges) + len(existing_edges) - 1
        invalid_contradictions = [i for i in contradicted_facts if i < 0 or i > max_valid_idx]
        if invalid_contradictions:
            logger.warning(
                'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d)',
                invalid_contradictions,
                max_valid_idx,
            )
        # Split contradicted facts into those from related_edges vs existing_edges based on offset
        for idx in contradicted_facts:
            if 0 <= idx < len(related_edges):
                # From EXISTING FACTS (duplicate candidates)
                invalidation_candidates.append(related_edges[idx])
            elif invalidation_idx_offset <= idx <= max_valid_idx:
                # From FACT INVALIDATION CANDIDATES (adjust index by offset)
                invalidation_candidates.append(existing_edges[idx - invalidation_idx_offset])
    # Only extract structured attributes if the edge's relation_type matches an allowed custom type
    # AND the edge model exists for this node pair signature
    edge_model = edge_type_candidates.get(resolved_edge.name) if edge_type_candidates else None
    if edge_model is not None and len(edge_model.model_fields) != 0:
        edge_attributes_context = {
            'fact': resolved_edge.fact,
            'reference_time': episode.valid_at if episode is not None else None,
            'existing_attributes': resolved_edge.attributes,
        }
        edge_attributes_response = await llm_client.generate_response(
            prompt_library.extract_edges.extract_attributes(edge_attributes_context),
            response_model=edge_model,  # type: ignore
            model_size=ModelSize.small,
            prompt_name='extract_edges.extract_attributes',
        )
        resolved_edge.attributes = edge_attributes_response
    else:
        # No matching custom type: clear any attributes carried over.
        resolved_edge.attributes = {}
    end = time()
    logger.debug(
        f'Resolved Edge: {extracted_edge.uuid} -> {resolved_edge.uuid}, in {(end - start) * 1000} ms'
    )
    now = utc_now()
    # An edge with a known end date that hasn't been expired yet gets expired now.
    if resolved_edge.invalid_at and not resolved_edge.expired_at:
        resolved_edge.expired_at = now
    # Determine if the new_edge needs to be expired
    if resolved_edge.expired_at is None:
        # Sort candidates so edges with a valid_at come first, oldest first.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
        for candidate in invalidation_candidates:
            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
            if (
                candidate_valid_at_utc is not None
                and resolved_edge_valid_at_utc is not None
                and candidate_valid_at_utc > resolved_edge_valid_at_utc
            ):
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break
    # Determine which contradictory edges need to be expired
    invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
        resolved_edge, invalidation_candidates
    )
    duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
    return resolved_edge, invalidated_edges, duplicate_edges
async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    """Drop node pairs already linked by an IS_DUPLICATE_OF edge in the graph.

    Queries the database (provider-specific Cypher) for existing
    IS_DUPLICATE_OF relationships among the given (source, target) pairs and
    returns only the pairs that do not yet have one.
    """
    if not duplicates_node_tuples:
        return []
    # Index pairs by their uuids so query results can be matched back.
    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }
    # Build the provider-specific query and its parameter payload.
    if driver.provider == GraphProvider.NEPTUNE:
        query: LiteralString = """
        UNWIND $duplicate_node_uuids AS duplicate_tuple
        MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
        RETURN DISTINCT
            n.uuid AS source_uuid,
            m.uuid AS target_uuid
        """
        duplicate_node_uuids = [
            {'source': source.uuid, 'target': target.uuid}
            for source, target in duplicates_node_tuples
        ]
    elif driver.provider == GraphProvider.KUZU:
        # Kuzu models RELATES_TO through an intermediate RelatesToNode_ node.
        query = """
        UNWIND $duplicate_node_uuids AS duplicate
        MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
        RETURN DISTINCT
            n.uuid AS source_uuid,
            m.uuid AS target_uuid
        """
        duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
    else:
        query = """
        UNWIND $duplicate_node_uuids AS duplicate_tuple
        MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
        RETURN DISTINCT
            n.uuid AS source_uuid,
            m.uuid AS target_uuid
        """
        duplicate_node_uuids = list(duplicate_nodes_map.keys())
    records, _, _ = await driver.execute_query(
        query,
        duplicate_node_uuids=duplicate_node_uuids,
        routing_='r',
    )
    # Remove pairs that already have the IS_DUPLICATE_OF edge.
    for record in records:
        found_pair = (record.get('source_uuid'), record.get('target_uuid'))
        duplicate_nodes_map.pop(found_pair, None)
    return list(duplicate_nodes_map.values())