Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions .github/workflows/publish-book.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: publish-book

on:
push:
branches:
- main

jobs:
deploy-book:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
python -m pip install .[docs]
pip install jupyter-book sphinxcontrib-mermaid

- name: Build the book
run: |
jupyter-book build .

- name: GitHub Pages action
uses: peaceiris/actions-gh-pages@v3.9.3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./_build/html
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ __pycache__
__pycache__/
*.py[cod]
*$py.class
notebooks/.ipynb_checkpoints/

# Binary files
*.jpg
Expand All @@ -23,6 +24,7 @@ __pycache__/
# Distribution / packaging
.Python
build/
_build/
develop-eggs/
dist/
downloads/
Expand Down
28 changes: 28 additions & 0 deletions _config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
title: AmadeusGPT
author: MLAI
logo: docs/logo.png
only_build_toc_files: true

sphinx:
config:
autodoc_mock_imports: list #["wx"]
extra_extensions:
- numpydoc

execute:
execute_notebooks: "off"

html:
extra_navbar: ""
use_issues_button: true
use_repository_button: true
extra_footer: |
<div>Powered by <a href="https://jupyterbook.org/">Jupyter Book</a>.</div>

repository:
url: https://github.com/AdaptiveMotorControlLab/AmadeusGPT
path_to_book: main
branch: main

launch_buttons:
colab_url: "https://colab.research.google.com/github.com/AdaptiveMotorControlLab/AmadeusGPT/examples/yourdemo.ipynb"
11 changes: 11 additions & 0 deletions _toc.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
format: jb-book
root: README
parts:
- caption: Using AmadeusGPT
chapters:
- file: notebooks/EPM_demo
- file: notebooks/Horse_demo
- file: notebooks/MABe_demo
- file: notebooks/MausHaus_demo
- file: notebooks/Use_Task_Program
- file: notebooks/YourData
4 changes: 4 additions & 0 deletions amadeusgpt/analysis_objects/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def _superanimal_inference(
):
import deeplabcut

# Patch for PyTorch 2.6 weights_only issue
from amadeusgpt.utils import patch_pytorch_weights_only
patch_pytorch_weights_only()

progress_obj = st.progress(0)
deeplabcut.video_inference_superanimal(
[video_file_path],
Expand Down
24 changes: 21 additions & 3 deletions amadeusgpt/managers/animal_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def __init__(self, identifier: Identifier):
self.full_keypoint_names = []
self.superanimal_predicted_video = None
self.superanimal_name = None
self.model_name = None
self.detector_name = None
self.init_pose()

def configure_animal_from_meta(self, meta_info):
Expand All @@ -106,11 +108,17 @@ def configure_animal_from_meta(self, meta_info):
self.max_individuals = int(meta_info["individuals"])
species = meta_info["species"]
if species == "topview_mouse":
self.superanimal_name = "superanimal_topviewmouse_hrnetw32"
self.superanimal_name = "superanimal_topviewmouse"
self.model_name = "hrnet_w32"
self.detector_name = "fasterrcnn_resnet50_fpn_v2"
elif species == "sideview_quadruped":
self.superanimal_name = "superanimal_quadruped_hrnetw32"
self.superanimal_name = "superanimal_quadruped"
self.model_name = "hrnet_w32"
self.detector_name = "fasterrcnn_resnet50_fpn_v2"
else:
self.superanimal_name = None
self.model_name = None
self.detector_name = None

def init_pose(self):

Expand Down Expand Up @@ -304,20 +312,30 @@ def get_keypoints(self) -> ndarray:
from deeplabcut.modelzoo.video_inference import \
video_inference_superanimal

# Patch for PyTorch 2.6+ weights_only issue
from amadeusgpt.utils import patch_pytorch_weights_only
patch_pytorch_weights_only()

video_suffix = Path(self.video_file_path).suffix

self.keypoint_file_path = self.video_file_path.replace(
video_suffix, "_" + self.superanimal_name + ".h5"
video_suffix, f"_superanimal_{self.superanimal_name.split('_', 1)[1]}_{self.detector_name}_{self.model_name}.h5"
)
self.superanimal_predicted_video = self.keypoint_file_path.replace(
".h5", "_labeled.mp4"
)

if not os.path.exists(self.keypoint_file_path):
print(f"going to inference video with {self.superanimal_name}")
if self.model_name is None:
raise ValueError("Model name not set. Please call configure_animal_from_meta first.")
if self.detector_name is None:
raise ValueError("Detector name not set. Please call configure_animal_from_meta first.")
video_inference_superanimal(
videos=[self.video_file_path],
superanimal_name=self.superanimal_name,
model_name=self.model_name,
detector_name=self.detector_name,
max_individuals=self.max_individuals,
video_adapt=False,
)
Expand Down
201 changes: 201 additions & 0 deletions amadeusgpt/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import ast
import inspect
import sys
import time
import traceback
from collections import defaultdict
import textwrap
import numpy as np
from amadeusgpt.analysis_objects.event import Event
from amadeusgpt.logger import AmadeusLogger
from IPython.display import Markdown, Video, display, HTML

def filter_kwargs_for_function(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}

def timer_decorator(func):
def wrapper(*args, **kwargs):
start_time = time.time() # before calling the function
result = func(*args, **kwargs) # call the function
end_time = time.time() # after calling the function
AmadeusLogger.debug(
f"The function {func.__name__} took {end_time - start_time} seconds to execute."
)
print(
f"The function {func.__name__} took {end_time - start_time} seconds to execute."
)
return result
return wrapper

def parse_error_message_from_python():
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback_str = "".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
return traceback_str

def validate_openai_api_key(key):
import openai
openai.api_key = key
try:
openai.models.list()
return True
except openai.AuthenticationError:
return False

def flatten_tuple(t):
"""
Used to handle function returns
"""
flattened = []
for item in t:
if isinstance(item, tuple):
flattened.extend(flatten_tuple(item))
else:
flattened.append(item)
return tuple(flattened)

def func2json(func):
if isinstance(func, str):
func_str = textwrap.dedent(func)
parsed = ast.parse(func_str)
func_def = parsed.body[0]
func_name = func_def.name
docstring = ast.get_docstring(func_def)
if (
func_def.body
and isinstance(func_def.body[0], ast.Expr)
and isinstance(func_def.body[0].value, (ast.Str, ast.Constant))
):
func_def.body.pop(0)
func_def.decorator_list = []
if hasattr(ast, "unparse"):
source_without_docstring_or_decorators = ast.unparse(func_def)
else:
source_without_docstring_or_decorators = None
return_annotation = "No return annotation"
if func_def.returns:
return_annotation = ast.unparse(func_def.returns)
json_obj = {
"name": func_name,
"inputs": "",
"source_code": source_without_docstring_or_decorators,
"docstring": docstring,
"return": return_annotation,
}
return json_obj
else:
sig = inspect.signature(func)
inputs = {name: str(param.annotation) for name, param in sig.parameters.items()}
docstring = inspect.getdoc(func)
if docstring:
docstring = textwrap.dedent(docstring)
full_source = inspect.getsource(func)
parsed = ast.parse(textwrap.dedent(full_source))
func_def = parsed.body[0]
if (
func_def.body
and isinstance(func_def.body[0], ast.Expr)
and isinstance(func_def.body[0].value, (ast.Str, ast.Constant))
):
func_def.body.pop(0)
func_def.decorator_list = []
if hasattr(ast, "unparse"):
source_without_docstring_or_decorators = ast.unparse(func_def)
else:
source_without_docstring_or_decorators = None
json_obj = {
"name": func.__name__,
"inputs": inputs,
"source_code": textwrap.dedent(source_without_docstring_or_decorators),
"docstring": docstring,
"return": str(sig.return_annotation),
}
return json_obj

class QA_Message:
def __init__(self, query: str, video_file_paths: list[str]):
self.query = query
self.video_file_paths = video_file_paths
self.code = None
self.chain_of_thought = None
self.error_message = defaultdict(list)
self.plots = defaultdict(list)
self.out_videos = defaultdict(list)
self.pose_video = defaultdict(list)
self.function_rets = defaultdict(list)
self.meta_info = {}
def get_masks(self) -> dict[str, np.ndarray]:
ret = {}
function_rets = self.function_rets
for video_path, rets in function_rets.items():
if isinstance(rets, list) and len(rets) > 0 and isinstance(rets[0], Event):
events = rets
masks = []
for event in events:
masks.append(event.generate_mask())
ret[video_path] = np.array(masks)
else:
ret[video_path] = None
return ret
def serialize_qa_message(self):
return {
"query": self.query,
"video_file_paths": self.video_file_paths,
"code": self.code,
"chain_of_thought": self.chain_of_thought,
"error_message": self.error_message,
"plots": None,
"out_videos": self.out_videos,
"pose_video": self.pose_video,
"function_rets": self.function_rets,
"meta_info": self.meta_info,
}
def create_qa_message(query: str, video_file_paths: list[str]) -> QA_Message:
return QA_Message(query, video_file_paths)
def parse_result(amadeus, qa_message, use_ipython=True, skip_code_execution=False):
if use_ipython:
display(Markdown(qa_message.chain_of_thought))
else:
print(qa_message.chain_of_thought)
sandbox = amadeus.sandbox
if not skip_code_execution:
qa_message = sandbox.code_execution(qa_message)
qa_message = sandbox.render_qa_message(qa_message)
if len(qa_message.out_videos) > 0:
print(f"videos generated to {qa_message.out_videos}")
print(
"Open it with media player if it does not properly display in the notebook"
)
if use_ipython:
if len(qa_message.out_videos) > 0:
for identifier, event_videos in qa_message.out_videos.items():
for event_video in event_videos:
display(Video(event_video, embed=True))
if use_ipython:
from matplotlib.animation import FuncAnimation
if len(qa_message.function_rets) > 0:
for identifier, rets in qa_message.function_rets.items():
if not isinstance(rets, (tuple, list)):
rets = [rets]
for ret in rets:
if isinstance(ret, FuncAnimation):
display(HTML(ret.to_jshtml()))
else:
display(Markdown(str(qa_message.function_rets[identifier])))
return qa_message

def patch_pytorch_weights_only():
"""
Patch for PyTorch 2.6 weights_only issue with DeepLabCut SuperAnimal models.
This adds safe globals to allow loading of ruamel.yaml.scalarfloat.ScalarFloat objects.
Only applies the patch if torch.serialization.add_safe_globals exists (PyTorch >=2.6).
"""
try:
import torch
from ruamel.yaml.scalarfloat import ScalarFloat
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([ScalarFloat])
except ImportError:
pass # If ruamel.yaml is not available, continue without the patch
Binary file added docs/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading