From 94d1a2f70a3e1b74ba642e289178d90aaaf0d0f9 Mon Sep 17 00:00:00 2001 From: Chengxin Wang Date: Fri, 19 Dec 2025 14:34:06 +0000 Subject: [PATCH] feat: migrate to beancount v3 --- pyproject.toml | 5 ++-- src/beanbot/importer/alipay.py | 6 ++--- src/beanbot/importer/bank_of_china.py | 12 +++++----- src/beanbot/importer/citic.py | 18 +++++++-------- src/beanbot/importer/csv_importer.py | 9 ++++---- src/beanbot/importer/deutsche_bank.py | 16 ++++++------- src/beanbot/importer/dkb.py | 14 ++++++------ src/beanbot/importer/hooks.py | 12 ++++++---- tests/data/import.config | 18 +++++++++++---- tests/importer/test_alipay.py | 33 ++++++++++++++++----------- 10 files changed, 81 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67c2c79..4550c1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,14 +8,15 @@ authors = [ readme = "README.md" requires-python = ">=3.12" dependencies = [ - "beancount>=2.3.6,<3.0.0", + "beancount>=3.0.0,<4.0.0", + "beangulp>=0.2.0", "scikit-learn>=1.5.0", "numpy>=1.26.4", "titlecase>=2.4.1", "dateparser>=1.2.0", "pandas>=2.2.2", "xlrd>=2.0.1", - "fava>=1.27.3", + "fava>=1.29", "regex>=2024.5.15", "pylint>=3.2.3", "yapf>=0.40.2", diff --git a/src/beanbot/importer/alipay.py b/src/beanbot/importer/alipay.py index fa56a45..05fedf1 100644 --- a/src/beanbot/importer/alipay.py +++ b/src/beanbot/importer/alipay.py @@ -32,9 +32,9 @@ def __init__( ) self._commission_account = commission_account - def identify(self, file) -> Match[str] | None: - print(file.name) - return re.match(r"^.*alipay/transactions/((?!archive).)*/.*\.csv$", file.name) + def identify(self, filepath: str) -> Match[str] | None: + print(filepath) + return re.match(r"^.*alipay/transactions/((?!archive).)*/.*\.csv$", filepath) def _parse_header(self, header_lines: List) -> List[Directive]: """ diff --git a/src/beanbot/importer/bank_of_china.py b/src/beanbot/importer/bank_of_china.py index 270d6e8..e927446 100644 --- a/src/beanbot/importer/bank_of_china.py +++ b/src/beanbot/importer/bank_of_china.py @@ -4,7 +4,7 @@ from beancount.core import data, flags from beancount.core.amount import Amount from beancount.core.number import D -from beancount.ingest import importer +from beangulp import importer from dateutil.parser import parse @@ -29,13 +29,13 @@ def __init__(self, account, currency="CNY"): self.account = account self.currency = currency - def identify(self, f): - return re.match(r"^.*boc/transactions/((?!archive).)*/.*\.csv$", f.name) + def identify(self, filepath: str): + return re.match(r"^.*boc/transactions/((?!archive).)*/.*\.csv$", filepath) - def extract(self, f, existing_entries=None): + def extract(self, filepath: str, existing_entries=None): entries = [] - with open(f.name, encoding="UTF-16-LE") as csvfile: + with open(filepath, encoding="UTF-16-LE") as csvfile: for index, row in enumerate(csv.DictReader(csvfile, delimiter="\t")): trans_date = parse(row["\ufeff交易日期"], yearfirst=True).date() trans_payee = row["对方账户名称"] @@ -51,7 +51,7 @@ def extract(self, f, existing_entries=None): else: trans_amount = Amount(-D(outgoing_amount), trans_currency) - meta = data.new_metadata(f.name, index) + meta = data.new_metadata(filepath, index) txn = data.Transaction( meta=meta, diff --git a/src/beanbot/importer/citic.py b/src/beanbot/importer/citic.py index cc5caf8..ede3ace 100644 --- a/src/beanbot/importer/citic.py +++ b/src/beanbot/importer/citic.py @@ -5,7 +5,7 @@ from beancount.core import data, flags from beancount.core.amount import Amount from beancount.core.number import D -from beancount.ingest import importer +from beangulp import importer def get_currency(currency): @@ -34,7 +34,7 @@ def __init__(self, account, card_type: str, lastfour: str | list[str]): lastfour = [lastfour] self.lastfour = lastfour - def identify(self, file): + def identify(self, filepath: str): """ assert raw transaction records are stored as: .../citic/transactions/[card_type]/YYYYMM.xls @@ -44,10 +44,10 @@ def identify(self, file): # return re.match('.*\.xls', os.path.basename(f.name)) # return re.match(f"^.*citic/transactions/((?!archive).)*/.*\.xls$", file.name) return re.match( - r"^.*citic/transactions/" + self.card_type + r"/.*\.xls$", file.name + r"^.*citic/transactions/" + self.card_type + r"/.*\.xls$", filepath ) - def extract(self, file, existing_entries=None): + def extract(self, filepath: str, existing_entries=None): """ format example (6393): 交易日期 入账日期 交易描述 卡末四位 交易币种 结算币种 交易金额 结算金额 @@ -59,12 +59,12 @@ def extract(self, file, existing_entries=None): """ entries = [] try: - dataframe = pandas.read_excel(io=file.name, sheet_name="本期账单明细") + dataframe = pandas.read_excel(io=filepath, sheet_name="本期账单明细") except ValueError: dataframe = pandas.read_excel( - io=file.name, sheet_name="本期账单明细(人民币)" + io=filepath, sheet_name="本期账单明细(人民币)" ) - card_type = file.name.rsplit("/")[-2] + card_type = filepath.rsplit("/")[-2] assert ( card_type == self.card_type ), f"Expect card type {self.card_type}, got {card_type}" @@ -108,7 +108,7 @@ def extract(self, file, existing_entries=None): if row_data.iloc[1].isdigit(): meta = data.new_metadata( - file.name, + filepath, index, { "booked_on": dateparser.parse( @@ -118,7 +118,7 @@ def extract(self, file, existing_entries=None): ) else: meta = data.new_metadata( - file.name, + filepath, index, {"booked_on": dateparser.parse(row_data.iloc[1]).date()}, ) diff --git a/src/beanbot/importer/csv_importer.py b/src/beanbot/importer/csv_importer.py index 655c095..aea6426 100644 --- a/src/beanbot/importer/csv_importer.py +++ b/src/beanbot/importer/csv_importer.py @@ -2,9 +2,8 @@ from re import Match from typing import List, Optional -from beancount.ingest import importer +from beangulp import importer from beancount.core.data import Directive -from beancount.ingest.cache import _FileMemo class CSVImporter(importer.ImporterProtocol): @@ -86,7 +85,7 @@ def _remove_whitespaces( return lines def extract( - self, file: _FileMemo, existing_entries: Optional[List[Directive]] = None + self, filepath: str, existing_entries: Optional[List[Directive]] = None ) -> List[Directive]: entries = [] @@ -96,7 +95,7 @@ def extract( self._file_meta = {} - with open(file.name, encoding=self._encoding) as csvfile: + with open(filepath, encoding=self._encoding) as csvfile: if self._header_lines: header_lines = [next(csvfile) for _ in range(self._header_lines)] assert ( @@ -126,7 +125,7 @@ def extract( csv.DictReader(body_lines, **self._csv_reader_kwargs) ): body_entries.extend( - self._parse_row_impl(row, file.name, index + self._header_lines + 1) + self._parse_row_impl(row, filepath, index + self._header_lines + 1) ) entries = sorted( diff --git a/src/beanbot/importer/deutsche_bank.py b/src/beanbot/importer/deutsche_bank.py index 0eed9b2..fddca56 100644 --- a/src/beanbot/importer/deutsche_bank.py +++ b/src/beanbot/importer/deutsche_bank.py @@ -4,7 +4,7 @@ from beancount.core import amount, data, flags from beancount.core.number import D -from beancount.ingest import importer +from beangulp import importer from dateutil.parser import parse @@ -12,21 +12,21 @@ class Importer(importer.ImporterProtocol): def __init__(self, account): self._account = account - def identify(self, f): + def identify(self, filepath: str): return re.match( - r"^.*deutsche_bank/transactions/((?!archive).)*/.*\.csv$", f.name + r"^.*deutsche_bank/transactions/((?!archive).)*/.*\.csv$", filepath ) - def extract(self, f, existing_entries=None): + def extract(self, filepath: str, existing_entries=None): entries = [] - with open(f.name, encoding="latin-1") as csvfile: + with open(filepath, encoding="latin-1") as csvfile: for index, row in enumerate(csv.DictReader(csvfile, delimiter=";")): trans_date = parse(row["Datum"], dayfirst=True).date() trans_payee = row["Auftraggeber / Empfänger"] trans_narration = row["Verwendungszweck"] trans_amount = row["Betrag"].replace(",", ".") - trans_meta = data.new_metadata(f.name, index) + trans_meta = data.new_metadata(filepath, index) # trans_meta['__source__'] = ';'.join(list(row.values())) txn = data.Transaction( @@ -61,6 +61,6 @@ def file_account(self, _): def file_name(self, _): return "Deutsche_Bank_Transaktionen" - def file_date(self, file): - date_str = re.search(r"\d{2}\-\d{2}\-\d{4}", str(file)).group(0) + def file_date(self, filepath: str): + date_str = re.search(r"\d{2}\-\d{2}\-\d{4}", filepath).group(0) return datetime.strptime(date_str, "%d-%m-%Y").date() diff --git a/src/beanbot/importer/dkb.py b/src/beanbot/importer/dkb.py index d7d9d3f..df5b2f0 100644 --- a/src/beanbot/importer/dkb.py +++ b/src/beanbot/importer/dkb.py @@ -6,7 +6,7 @@ from beancount.core import data, flags from beancount.core.amount import Amount from beancount.core.number import D -from beancount.ingest import importer +from beangulp import importer import dateutil import parse @@ -21,14 +21,14 @@ def __init__(self, account, lastfour): self._account = account self._lastfour = lastfour - def identify(self, f): + def identify(self, filepath: str): # Match based on filename structure, assuming last four digits are part of the path - return re.match(rf"^.*dkb/transactions/{self._lastfour}/.*\.csv$", f.name) + return re.match(rf"^.*dkb/transactions/{self._lastfour}/.*\.csv$", filepath) - def extract(self, f, existing_entries=None): + def extract(self, filepath: str, existing_entries=None): entries = [] - with open(f.name, encoding="UTF-8") as csvfile: + with open(filepath, encoding="UTF-8") as csvfile: # Read the new 4-line header csv_header = [next(csvfile).strip().replace('"', "") for _ in range(4)] @@ -71,7 +71,7 @@ def extract(self, f, existing_entries=None): balance_amount = Amount(D(bal_val), bal_currency) balance = data.Balance( - meta=data.new_metadata(f.name, 3), # Line number in header + meta=data.new_metadata(filepath, 3), # Line number in header date=balance_date + timedelta(days=1), # Beancount balance is start of day after account=self._account, @@ -103,7 +103,7 @@ def extract(self, f, existing_entries=None): customer_reference = string_cleaning(row["Kundenreferenz"].strip()) trans_meta = data.new_metadata( - f.name, index + 6 + filepath, index + 6 ) # Adjust line number accounting for header # Note: For metadata keys must begin with a lowercase character if sepa_creditor_id != "": diff --git a/src/beanbot/importer/hooks.py b/src/beanbot/importer/hooks.py index 84e2e76..4a6a2dd 100644 --- a/src/beanbot/importer/hooks.py +++ b/src/beanbot/importer/hooks.py @@ -4,7 +4,7 @@ from beancount.loader import load_file from beancount.core.data import Entries -from beancount.ingest.importer import ImporterProtocol +from beangulp.importer import ImporterProtocol from beanbot.classifier.meta_transaction_classifier import MetaTransactionClassifier from beancount.core import data from beanbot.ops.filter import TransactionFilter, NotTransactionFilter @@ -51,12 +51,16 @@ def apply_hooks( unpatched_extract = importer.extract @wraps(unpatched_extract) - def patched_extract_method(file, existing_entries=None): + def patched_extract_method(filepath: str, existing_entries=None): logger.debug("Calling the importer's extract method.") - imported_entries = unpatched_extract(file, existing_entries=existing_entries) + imported_entries = unpatched_extract( + filepath, existing_entries=existing_entries + ) for hook in hooks: - imported_entries = hook(importer, file, imported_entries, existing_entries) + imported_entries = hook( + importer, filepath, imported_entries, existing_entries + ) return imported_entries diff --git a/tests/data/import.config b/tests/data/import.config index 6cfe0f8..b83a638 100644 --- a/tests/data/import.config +++ b/tests/data/import.config @@ -5,9 +5,17 @@ from beanbot.importer import deutsche_bank, citic, bank_of_china, dkb, alipay CONFIG = [ - apply_hooks(deutsche_bank.Importer('Assets:Checking:DeutscheBank'), [BeanBotPredictionHook()]), - apply_hooks(citic.Importer('Liabilities:Credit:Citic:Visa', 'visa', '0000'), [BeanBotPredictionHook()]), - apply_hooks(bank_of_china.Importer('Assets:Checking:BankOfChina'), [BeanBotPredictionHook()]), - apply_hooks(dkb.Importer('Assets:Checking:DKB', '0000'), [BeanBotPredictionHook()]), - alipay.Importer('Assets:Checking:Alipay', 'Expenses:Financial:Commissions'), + apply_hooks( + deutsche_bank.Importer("Assets:Checking:DeutscheBank"), + [BeanBotPredictionHook()], + ), + apply_hooks( + citic.Importer("Liabilities:Credit:Citic:Visa", "visa", "0000"), + [BeanBotPredictionHook()], + ), + apply_hooks( + bank_of_china.Importer("Assets:Checking:BankOfChina"), [BeanBotPredictionHook()] + ), + apply_hooks(dkb.Importer("Assets:Checking:DKB", "0000"), [BeanBotPredictionHook()]), + alipay.Importer("Assets:Checking:Alipay", "Expenses:Financial:Commissions"), ] diff --git a/tests/importer/test_alipay.py b/tests/importer/test_alipay.py index 968cfed..5efeb2f 100644 --- a/tests/importer/test_alipay.py +++ b/tests/importer/test_alipay.py @@ -1,15 +1,24 @@ from pathlib import Path -import subprocess +from beancount.parser import printer -def run_cli_command(): - command = [ - "bean-extract", - "tests/data/import.config", - "tests/data/raw/alipay/transactions/foo@bar.com/alipay_record_20240101_0000_1.csv", - ] - result = subprocess.run(command, capture_output=True, text=True) - return result.stdout +from beanbot.importer.alipay import Importer + + +def run_extraction(): + """Extract entries from the test alipay CSV file using the importer directly.""" + importer = Importer("Assets:Checking:Alipay", "Expenses:Financial:Commissions") + filename = "tests/data/raw/alipay/transactions/foo@bar.com/alipay_record_20240101_0000_1.csv" + + # Extract entries directly from the importer + entries = importer.extract(filename, existing_entries=[]) + + # Format entries as beancount text + output_lines = [] + for entry in entries: + output_lines.append(printer.format_entry(entry)) + + return "\n".join(output_lines) def read_file_content(filepath): @@ -23,9 +32,7 @@ def compare_output(expected_file, actual_output): expected_content = read_file_content(expected_file) actual_content = actual_output.splitlines() - # Ignore the first 4 rows - actual_content = actual_content[4:] - + # No longer need to skip rows since we're not using bean-extract CLI expected_content = [line.strip() for line in expected_content if line.strip()] actual_content = [line.strip() for line in actual_content if line.strip()] assert ( @@ -34,5 +41,5 @@ def compare_output(expected_file, actual_output): def test_cli_command_output(): - actual_output = run_cli_command() + actual_output = run_extraction() compare_output("tests/data/expected/importer/alipay.bean", actual_output)