Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Added

- Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j.

- Exposed optional `sample` parameter on `get_schema` and `get_structured_schema` to control APOC sampling for schema discovery.
- Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval.

## 1.10.1
Expand Down
21 changes: 15 additions & 6 deletions src/neo4j_graphrag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
DISTINCT_VALUE_LIMIT = 10

NODE_PROPERTIES_QUERY = (
"CALL apoc.meta.data() "
"CALL apoc.meta.data({sample: $SAMPLE}) "
"YIELD label, other, elementType, type, property "
"WHERE NOT type = 'RELATIONSHIP' AND elementType = 'node' "
"AND NOT label IN $EXCLUDED_LABELS "
Expand All @@ -38,7 +38,7 @@
)

REL_PROPERTIES_QUERY = (
"CALL apoc.meta.data() "
"CALL apoc.meta.data({sample: $SAMPLE}) "
"YIELD label, other, elementType, type, property "
"WHERE NOT type = 'RELATIONSHIP' AND elementType = 'relationship' "
"AND NOT label in $EXCLUDED_LABELS "
Expand All @@ -47,7 +47,7 @@
)

REL_QUERY = (
"CALL apoc.meta.data() "
"CALL apoc.meta.data({sample: $SAMPLE}) "
"YIELD label, other, elementType, type, property "
"WHERE type = 'RELATIONSHIP' AND elementType = 'node' "
"UNWIND other AS other_node "
Expand Down Expand Up @@ -186,6 +186,7 @@ def get_schema(
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
sample: int = 1000,
) -> str:
"""
Returns the schema of the graph as a string with following format:
Expand All @@ -210,6 +211,8 @@ def get_schema(
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling.
Defaults to 1000.


Returns:
Expand All @@ -221,6 +224,7 @@ def get_schema(
database=database,
timeout=timeout,
sanitize=sanitize,
sample=sample,
)
return format_schema(structured_schema, is_enhanced)

Expand All @@ -231,6 +235,7 @@ def get_structured_schema(
database: Optional[str] = None,
timeout: Optional[float] = None,
sanitize: bool = False,
sample: int = 1000,
) -> dict[str, Any]:
"""
Returns the structured schema of the graph.
Expand Down Expand Up @@ -280,6 +285,8 @@ def get_structured_schema(
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling.
Defaults to 1000.

Returns:
dict[str, Any]: the graph schema information in a structured format.
Expand All @@ -291,7 +298,8 @@ def get_structured_schema(
query=NODE_PROPERTIES_QUERY,
params={
"EXCLUDED_LABELS": EXCLUDED_LABELS
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
"SAMPLE": sample,
},
database=database,
timeout=timeout,
Expand All @@ -304,7 +312,7 @@ def get_structured_schema(
for data in query_database(
driver=driver,
query=REL_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_RELS},
params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": sample},
database=database,
timeout=timeout,
sanitize=sanitize,
Expand All @@ -318,7 +326,8 @@ def get_structured_schema(
query=REL_QUERY,
params={
"EXCLUDED_LABELS": EXCLUDED_LABELS
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
+ [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
"SAMPLE": sample,
},
database=database,
timeout=timeout,
Expand Down
12 changes: 9 additions & 3 deletions tests/e2e/test_schema_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
def test_cypher_returns_correct_node_properties(driver: Driver) -> None:
node_properties = query_database(
driver, NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
driver,
NODE_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
)

expected_node_properties = [
Expand All @@ -47,7 +49,9 @@ def test_cypher_returns_correct_node_properties(driver: Driver) -> None:
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None:
relationships_properties = query_database(
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
driver,
REL_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
)

expected_relationships_properties = [
Expand All @@ -65,7 +69,9 @@ def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None:
@pytest.mark.usefixtures("setup_neo4j_for_schema_query")
def test_cypher_returns_correct_relationships(driver: Driver) -> None:
relationships = query_database(
driver, REL_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]}
driver,
REL_QUERY,
params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000},
)

expected_relationships = [
Expand Down
11 changes: 8 additions & 3 deletions tests/e2e/test_schema_filters_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_filtering_labels_node_properties(driver: Driver) -> None:
for data in query_database(
driver,
NODE_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS},
params={"EXCLUDED_LABELS": EXCLUDED_LABELS, "SAMPLE": 1000},
)
]

Expand All @@ -45,7 +45,9 @@ def test_filtering_labels_relationship_properties(driver: Driver) -> None:
relationship_properties = [
data["output"]
for data in query_database(
driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS}
driver,
REL_PROPERTIES_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000},
)
]

Expand All @@ -59,7 +61,10 @@ def test_filtering_labels_relationships(driver: Driver) -> None:
for data in query_database(
driver,
REL_QUERY,
params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]},
params={
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL],
"SAMPLE": 1000,
},
)
]

Expand Down
8 changes: 5 additions & 3 deletions tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
assert query_obj.timeout is None
assert kwargs["database_"] is None
assert kwargs["parameters_"] == {
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
"SAMPLE": 1000,
}

args, kwargs = calls[1]
Expand All @@ -106,7 +107,7 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
assert query_obj.text == REL_PROPERTIES_QUERY
assert query_obj.timeout is None
assert kwargs["database_"] is None
assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS}
assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000}

args, kwargs = calls[2]
query_obj = args[0]
Expand All @@ -115,7 +116,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None:
assert query_obj.timeout is None
assert kwargs["database_"] is None
assert kwargs["parameters_"] == {
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL]
"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL],
"SAMPLE": 1000,
}

args, kwargs = calls[3]
Expand Down