Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 13 additions & 33 deletions tools/lib/framereader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import subprocess
import json
import logging
from functools import cache
from collections.abc import Iterator
from collections import OrderedDict

Expand All @@ -17,19 +16,6 @@
HEVC_SLICE_P = 1
HEVC_SLICE_I = 2

@cache
def get_hw_accel() -> list[str]:
    """Detect and return the best available ffmpeg hardware acceleration.

    Runs ``ffmpeg -hwaccels`` and returns the first method, in a fixed
    priority order, whose name appears in the command's output. The result
    is memoized by ``@cache``, so ffmpeg is probed at most once per process
    (a failed probe is memoized as well).

    Returns:
        ``["-hwaccel", <method>]`` to splice into an ffmpeg argument list, or
        an empty list when no preferred method is reported or when ffmpeg
        could not be probed.
    """
    priority = ("videotoolbox", "cuda", "vaapi", "d3d11va")
    try:
        result = subprocess.run(["ffmpeg", "-hwaccels"], capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.TimeoutExpired) as e:
        # A missing ffmpeg binary or a hung probe should not crash the caller:
        # take the same CPU-decode fallback as "nothing found".
        logger.warning(f"could not run `ffmpeg -hwaccels` ({e}). falling back to ffmpeg CPU decode")
        return []

    # Lower-case once instead of on every loop iteration.
    reported = result.stdout.lower()
    for accel in priority:
        if accel in reported:
            logger.info(f"HW accelerated video decode found, using ffmpeg's {accel}")
            return ["-hwaccel", accel]
    logger.warning("no HW accelerated video found with `ffmpeg -hwaccels`. falling back to ffmpeg CPU decode")
    return []


class LRUCache:
def __init__(self, capacity: int):
self._cache: OrderedDict = OrderedDict()
Expand All @@ -47,7 +33,6 @@ def __setitem__(self, key, value):
def __contains__(self, key):
return key in self._cache


def assert_hvec(fn: str) -> None:
with FileReader(fn) as f:
header = f.read(4)
Expand All @@ -57,11 +42,11 @@ def assert_hvec(fn: str) -> None:
if 'hevc' not in fn:
raise NotImplementedError(fn)

def decompress_video_data(rawdat, w, h, pix_fmt="rgb24", vid_fmt='hevc') -> np.ndarray:
def decompress_video_data(rawdat, w, h, pix_fmt="rgb24", vid_fmt='hevc', loglevel="info") -> np.ndarray:
threads = os.getenv("FFMPEG_THREADS", "0")
args = ["ffmpeg", "-v", "quiet",
args = ["ffmpeg", "-v", loglevel,
"-threads", threads,
*get_hw_accel(),
"-hwaccel", "auto",
"-c:v", "hevc",
"-vsync", "0",
"-f", vid_fmt,
Expand Down Expand Up @@ -114,15 +99,15 @@ def get_video_index(fn):
'probe': probe
}


class FfmpegDecoder:
def __init__(self, fn: str, index_data: dict|None = None,
pix_fmt: str = "rgb24"):
pix_fmt: str = "rgb24", loglevel="quiet"):
self.fn = fn
self.index, self.prefix, self.w, self.h = get_index_data(fn, index_data)
self.frame_count = len(self.index) - 1 # sentinel row at the end
self.iframes = np.where(self.index[:, 0] == HEVC_SLICE_I)[0]
self.pix_fmt = pix_fmt
self.loglevel = loglevel

def _gop_bounds(self, frame_idx: int):
f_b = frame_idx
Expand All @@ -134,7 +119,7 @@ def _gop_bounds(self, frame_idx: int):
return f_b, f_e, self.index[f_b, 1], self.index[f_e, 1]

def _decode_gop(self, raw: bytes) -> Iterator[np.ndarray]:
yield from decompress_video_data(raw, self.w, self.h, self.pix_fmt)
yield from decompress_video_data(raw, self.w, self.h, self.pix_fmt, loglevel=self.loglevel)

def get_gop_start(self, frame_idx: int):
return self.iframes[np.searchsorted(self.iframes, frame_idx, side="right") - 1]
Expand All @@ -149,25 +134,24 @@ def get_iterator(self, start_fidx: int = 0, end_fidx: int|None = None,
f.seek(off_b)
raw = self.prefix + f.read(off_e - off_b)
# number of frames to discard inside this GOP before the wanted one
for i, frm in enumerate(decompress_video_data(raw, self.w, self.h, self.pix_fmt)):
for i, frm in enumerate(decompress_video_data(raw, self.w, self.h, self.pix_fmt, loglevel=self.loglevel)):
fidx = f_b + i
if fidx >= end_fidx:
return
elif fidx >= start_fidx and (fidx - start_fidx) % frame_skip == 0:
yield fidx, frm
fidx += 1

def FrameIterator(fn: str, index_data: dict|None=None, pix_fmt: str = "rgb24",
                  start_fidx: int = 0, end_fidx: int|None = None, frame_skip: int = 1,
                  loglevel: str = "quiet") -> Iterator[np.ndarray]:
    """Yield decoded frames from the video file ``fn``.

    Builds an ``FfmpegDecoder`` for ``fn`` and yields only the frame from each
    ``(frame index, frame)`` pair produced by its ``get_iterator``.

    Args:
        fn: Path of the video to decode.
        index_data: Optional index data handed to ``FfmpegDecoder``.
        pix_fmt: Pixel format handed to ``FfmpegDecoder``.
        start_fidx: Index of the first frame to yield.
        end_fidx: Frames at or after this index are not yielded.
            NOTE(review): handling of ``None`` happens inside ``get_iterator``
            and is not visible here -- presumably "until the last frame"; confirm.
        frame_skip: Yield every ``frame_skip``-th frame counted from ``start_fidx``.
        loglevel: Log level the decoder passes to ffmpeg's ``-v`` option.

    Yields:
        One decoded frame per selected frame index.
    """
    dec = FfmpegDecoder(fn, pix_fmt=pix_fmt, index_data=index_data, loglevel=loglevel)
    for _, frame in dec.get_iterator(start_fidx=start_fidx, end_fidx=end_fidx, frame_skip=frame_skip):
        yield frame

class FrameReader:
def __init__(self, fn: str, index_data: dict|None = None,
cache_size: int = 30, pix_fmt: str = "rgb24"):
self.decoder = FfmpegDecoder(fn, index_data, pix_fmt)
def __init__(self, fn: str, index_data: dict|None = None, cache_size: int = 30,
pix_fmt: str = "rgb24", loglevel="quiet"):
self.decoder = FfmpegDecoder(fn, index_data, pix_fmt, loglevel=loglevel)
self.iframes = self.decoder.iframes
self._cache: LRUCache = LRUCache(cache_size)
self.w, self.h, self.frame_count, = self.decoder.w, self.decoder.h, self.decoder.frame_count
Expand All @@ -187,7 +171,3 @@ def get(self, fidx:int):
self.fidx, frame = next(self.it)
self._cache[self.fidx] = frame
return self._cache[fidx]


if __name__ == "__main__":
    # Script entry point: probe ffmpeg's hardware acceleration once.
    get_hw_accel()
Loading