From 471c53e2ba7f5dce4aaaf7ffd3b60af55e1a1da6 Mon Sep 17 00:00:00 2001 From: Charlie Summers Date: Tue, 4 Jan 2022 20:44:31 -0800 Subject: [PATCH] Dockerize watts --- .gitignore | 3 ++- Dockerfile | 6 ++++++ Makefile | 10 ++++++++++ README.md | 10 ++++------ logs/.gitkeep | 0 poet_distributed.py | 8 ++++---- sample_args/args.yaml | 4 ++-- setup.py | 2 +- snapshots/.gitkeep | 0 watts/serializer/POETManagerSerializer.py | 4 ++-- watts/solvers/SingleAgentSolver.py | 2 +- watts/utils/gym_wrappers.py | 2 +- 12 files changed, 33 insertions(+), 18 deletions(-) create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 logs/.gitkeep create mode 100644 snapshots/.gitkeep diff --git a/.gitignore b/.gitignore index 5bc6811..762bed6 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ */__pychache__/* *.pyc -snapshot.pkl +snapshots/* +logs/* *.egg-info diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..e8dce22 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,6 @@ +FROM rayproject/ray-ml:1.6.0 + +# Copy watts/* into /home/ray/ +COPY . . + +RUN pip install -e . diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..dd98727 --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +build: + docker build -t watts . + +run: + docker run -i -t \ + --shm-size=8G \ + -p 8265:8265 \ + --mount type=bind,source=`pwd`/logs,target=/home/ray/logs \ + --mount type=bind,source=`pwd`/snapshots,target=/home/ray/snapshots \ + watts python poet_distributed.py diff --git a/README.md b/README.md index e5c9624..e120e4c 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ -# enigma +# watts PINSKY v2 This is a rewrite and extension of the UntouchableThunder repo. -**Furthermore, we aim for Enigma to generically tackle the +**Furthermore, we aim for Watts to generically tackle the problem of simultaneous learning with generators and solvers.** In PINSKY (v1.0), I manually created futures and collected answers after each distributed call. 
@@ -15,10 +15,8 @@ by a user, and to cleanly scale to arbitrary compute. Installation: - * conda create -n NAME python=3.7.10 - * conda activate NAME - * Install pytorch according to your system and environment from here: https://pytorch.org/get-started/locally/ - * At the root of this project, run: `pip install -e .` + * make + * make run If you see stuff like: ``` diff --git a/logs/.gitkeep b/logs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/poet_distributed.py b/poet_distributed.py index 97c4913..e7c2b8d 100644 --- a/poet_distributed.py +++ b/poet_distributed.py @@ -60,9 +60,9 @@ def save_obj(obj, folder, name): network_factory = NetworkFactory(registry.network_name, registry.get_nn_build_info) - generator = StaticGenerator(args.initial_level_string) - #generator = EvolutionaryGenerator(args.initial_level_string, - # file_args=registry.get_generator_config) + # generator = StaticGenerator(args.initial_level_string) + generator = EvolutionaryGenerator(args.initial_level_string, + file_args=registry.get_generator_config) if args.use_snapshot: manager = POETManagerSerializer.deserialize() @@ -96,7 +96,7 @@ def save_obj(obj, folder, name): _release(manager._evolution_strategy._replacement_strategy.archive_history, manager.active_population) manager._evolution_strategy._replacement_strategy.archive_history['run_stats'] = manager.stats save_obj(manager._evolution_strategy._replacement_strategy.archive_history, - os.path.join('..', 'enigma_logs', _args.exp_name), + os.path.join('.', 'logs', _args.exp_name), 'total_serialized_alg') elapsed = time.time() - start diff --git a/sample_args/args.yaml b/sample_args/args.yaml index 88860ca..65d47af 100644 --- a/sample_args/args.yaml +++ b/sample_args/args.yaml @@ -41,14 +41,14 @@ game_len: 500 opt_algo: "PPO" # algo generation args -evolution_timer: 25 +evolution_timer: 5 mutation_rate: 0.8 max_children: 3 max_envs: 10 comp_agent: mcts # algo general args -num_poet_loops: 2 +num_poet_loops: 100 
transfer_timer: 10 snapshot_timer: 20 use_snapshot: False diff --git a/setup.py b/setup.py index 53170d1..4a9c149 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ author='Aaron Dharna , Charlie Summers , Rohin Dasari ', url='https://github.com/aadharna/watts', packages=['watts'], - python_requires='>=3.7.10', + python_requires='>=3.7.7', install_requires=[ 'ray[all]==1.6.0', 'griddly', diff --git a/snapshots/.gitkeep b/snapshots/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/watts/serializer/POETManagerSerializer.py b/watts/serializer/POETManagerSerializer.py index fbac0e0..81aaa8f 100644 --- a/watts/serializer/POETManagerSerializer.py +++ b/watts/serializer/POETManagerSerializer.py @@ -15,12 +15,12 @@ def __init__(self, manager: POETManager): self.manager = manager def serialize(self): - with open('snapshot.pkl', 'wb') as f: + with open('./snapshots/snapshot.pkl', 'wb') as f: pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) @staticmethod def deserialize() -> POETManager: - with open('snapshot.pkl', 'rb') as f: + with open('./snapshots/snapshot.pkl', 'rb') as f: return pickle.load(f).manager def __getstate__(self): diff --git a/watts/solvers/SingleAgentSolver.py b/watts/solvers/SingleAgentSolver.py index abe200c..8862c04 100644 --- a/watts/solvers/SingleAgentSolver.py +++ b/watts/solvers/SingleAgentSolver.py @@ -27,7 +27,7 @@ def __init__(self, trainer_constructor, trainer_config, registered_gym_name, net self.network_factory = network_factory self.gym_factory = gym_factory self.trainer = trainer_constructor(config=trainer_config, env=registered_gym_name, - logger_creator=custom_log_creator(os.path.join('..', 'enigma_logs', self.exp), + logger_creator=custom_log_creator(os.path.join('.', 'logs', self.exp), f'POET_{log_id}.') ) self.agent = network_factory.make()(weights) diff --git a/watts/utils/gym_wrappers.py b/watts/utils/gym_wrappers.py index 8efb1c7..9670202 100644 --- a/watts/utils/gym_wrappers.py +++ b/watts/utils/gym_wrappers.py @@ 
-351,6 +351,6 @@ def policy_mapping_fn(agent_id): stop = {"timesteps_total": 500000} results = tune.run(PPOTrainer, config=config2, stop=stop, - local_dir=os.path.join('..', 'enigma_logs'), checkpoint_at_end=True) + local_dir=os.path.join('.', 'logs'), checkpoint_at_end=True) ray.shutdown()