diff --git a/CHANGELOG.md b/CHANGELOG.md index 7175ec6a..cdcf2621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ ### Added - Added an optional `node_label_neo4j` parameter in the external retrievers to speed up the search query in Neo4j. + +- Exposed optional `sample` parameter on `get_schema` and `get_structured_schema` to control APOC sampling for schema discovery. - Added an optional `id_property_getter` callable parameter in the Qdrant retriever to allow for custom ID retrieval. ## 1.10.1 diff --git a/src/neo4j_graphrag/schema.py b/src/neo4j_graphrag/schema.py index 40292067..5f299e11 100644 --- a/src/neo4j_graphrag/schema.py +++ b/src/neo4j_graphrag/schema.py @@ -29,7 +29,7 @@ DISTINCT_VALUE_LIMIT = 10 NODE_PROPERTIES_QUERY = ( - "CALL apoc.meta.data() " + "CALL apoc.meta.data({sample: $SAMPLE}) " "YIELD label, other, elementType, type, property " "WHERE NOT type = 'RELATIONSHIP' AND elementType = 'node' " "AND NOT label IN $EXCLUDED_LABELS " @@ -38,7 +38,7 @@ ) REL_PROPERTIES_QUERY = ( - "CALL apoc.meta.data() " + "CALL apoc.meta.data({sample: $SAMPLE}) " "YIELD label, other, elementType, type, property " "WHERE NOT type = 'RELATIONSHIP' AND elementType = 'relationship' " "AND NOT label in $EXCLUDED_LABELS " @@ -47,7 +47,7 @@ ) REL_QUERY = ( - "CALL apoc.meta.data() " + "CALL apoc.meta.data({sample: $SAMPLE}) " "YIELD label, other, elementType, type, property " "WHERE type = 'RELATIONSHIP' AND elementType = 'node' " "UNWIND other AS other_node " @@ -186,6 +186,7 @@ def get_schema( database: Optional[str] = None, timeout: Optional[float] = None, sanitize: bool = False, + sample: int = 1000, ) -> str: """ Returns the schema of the graph as a string with following format: @@ -210,6 +211,8 @@ def get_schema( sanitize (bool): A flag to indicate whether to remove lists with more than 128 elements from results. Useful for removing embedding-like properties from database responses. Default is False. + sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling. + Defaults to 1000. Returns: @@ -221,6 +224,7 @@ def get_schema( database=database, timeout=timeout, sanitize=sanitize, + sample=sample, ) return format_schema(structured_schema, is_enhanced) @@ -231,6 +235,7 @@ def get_structured_schema( database: Optional[str] = None, timeout: Optional[float] = None, sanitize: bool = False, + sample: int = 1000, ) -> dict[str, Any]: """ Returns the structured schema of the graph. @@ -280,6 +285,8 @@ def get_structured_schema( sanitize (bool): A flag to indicate whether to remove lists with more than 128 elements from results. Useful for removing embedding-like properties from database responses. Default is False. + sample (int): Number of nodes to sample for the apoc.meta.data procedure. Setting sample to -1 will remove sampling. + Defaults to 1000. Returns: dict[str, Any]: the graph schema information in a structured format. @@ -291,7 +298,8 @@ def get_structured_schema( query=NODE_PROPERTIES_QUERY, params={ "EXCLUDED_LABELS": EXCLUDED_LABELS - + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL] + + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL], + "SAMPLE": sample, }, database=database, timeout=timeout, @@ -304,7 +312,7 @@ def get_structured_schema( for data in query_database( driver=driver, query=REL_PROPERTIES_QUERY, - params={"EXCLUDED_LABELS": EXCLUDED_RELS}, + params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": sample}, database=database, timeout=timeout, sanitize=sanitize, @@ -318,7 +326,8 @@ def get_structured_schema( query=REL_QUERY, params={ "EXCLUDED_LABELS": EXCLUDED_LABELS - + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL] + + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL], + "SAMPLE": sample, }, database=database, timeout=timeout, diff --git a/tests/e2e/test_schema_e2e.py b/tests/e2e/test_schema_e2e.py index 5226e985..e362ed43 100644 --- a/tests/e2e/test_schema_e2e.py +++ b/tests/e2e/test_schema_e2e.py @@ -29,7 +29,9 @@ @pytest.mark.usefixtures("setup_neo4j_for_schema_query") def test_cypher_returns_correct_node_properties(driver: Driver) -> None: node_properties = query_database( - driver, NODE_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + driver, + NODE_PROPERTIES_QUERY, + params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000}, ) expected_node_properties = [ @@ -47,7 +49,9 @@ def test_cypher_returns_correct_node_properties(driver: Driver) -> None: @pytest.mark.usefixtures("setup_neo4j_for_schema_query") def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None: relationships_properties = query_database( - driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + driver, + REL_PROPERTIES_QUERY, + params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000}, ) expected_relationships_properties = [ @@ -65,7 +69,9 @@ def test_cypher_returns_correct_relationship_properties(driver: Driver) -> None: @pytest.mark.usefixtures("setup_neo4j_for_schema_query") def test_cypher_returns_correct_relationships(driver: Driver) -> None: relationships = query_database( - driver, REL_QUERY, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} + driver, + REL_QUERY, + params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL], "SAMPLE": 1000}, ) expected_relationships = [ diff --git a/tests/e2e/test_schema_filters_e2e.py b/tests/e2e/test_schema_filters_e2e.py index 5e79cdca..5d11d453 100644 --- a/tests/e2e/test_schema_filters_e2e.py +++ b/tests/e2e/test_schema_filters_e2e.py @@ -33,7 +33,7 @@ def test_filtering_labels_node_properties(driver: Driver) -> None: for data in query_database( driver, NODE_PROPERTIES_QUERY, - params={"EXCLUDED_LABELS": EXCLUDED_LABELS}, + params={"EXCLUDED_LABELS": EXCLUDED_LABELS, "SAMPLE": 1000}, ) ] @@ -45,7 +45,9 @@ def test_filtering_labels_relationship_properties(driver: Driver) -> None: relationship_properties = [ data["output"] for data in query_database( - driver, REL_PROPERTIES_QUERY, params={"EXCLUDED_LABELS": EXCLUDED_RELS} + driver, + REL_PROPERTIES_QUERY, + params={"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000}, ) ] @@ -59,7 +61,10 @@ def test_filtering_labels_relationships(driver: Driver) -> None: for data in query_database( driver, REL_QUERY, - params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, + params={ + "EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL], + "SAMPLE": 1000, + }, ) ] diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 148be55d..656b2d7b 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -97,7 +97,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None: assert query_obj.timeout is None assert kwargs["database_"] is None assert kwargs["parameters_"] == { - "EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL] + "EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL], + "SAMPLE": 1000, } args, kwargs = calls[1] @@ -106,7 +107,7 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None: assert query_obj.text == REL_PROPERTIES_QUERY assert query_obj.timeout is None assert kwargs["database_"] is None - assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS} + assert kwargs["parameters_"] == {"EXCLUDED_LABELS": EXCLUDED_RELS, "SAMPLE": 1000} args, kwargs = calls[2] query_obj = args[0] @@ -115,7 +116,8 @@ def test_get_structured_schema_happy_path(driver: MagicMock) -> None: assert query_obj.timeout is None assert kwargs["database_"] is None assert kwargs["parameters_"] == { - "EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL] + "EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL, BASE_KG_BUILDER_LABEL], + "SAMPLE": 1000, } args, kwargs = calls[3]