diff --git a/sc2/client.py b/sc2/client.py index 249fa877..2c095b8f 100644 --- a/sc2/client.py +++ b/sc2/client.py @@ -8,6 +8,7 @@ from s2clientprotocol import raw_pb2 as raw_pb from s2clientprotocol import sc2api_pb2 as sc_pb from s2clientprotocol import spatial_pb2 as spatial_pb +from s2clientprotocol import ui_pb2 as ui_pb from sc2.action import combine_actions from sc2.data import ActionResult, ChatChannel, Race, Result, Status @@ -44,8 +45,9 @@ def __init__(self, ws, save_replay_path: str = None): self._debug_spheres = [] self._renderer = None - self.raw_affects_selection = False - + self.raw_affects_selection = True + self.enable_feature_layer = True + @property def in_game(self) -> bool: return self._status in {Status.in_game, Status.in_replay} @@ -118,7 +120,46 @@ async def leave(self): except (ProtocolError, ConnectionAlreadyClosed): if is_resign: raise + + async def unload_unit(self, transporter_unit: Unit, cargo_unit: Unit = False): + """ + Unloads single unit passed by cargo_unit or first unit in transporter if not cargo_unit passed + transporter_unit includes all units, which can unload cargo: + warp prism, medivac, bunker, command center, command center flying, planetary fortress, droppenlord, nydlus + + Usage: + self.client.unload_unit(transporter_unit, cargo_unit) + self.client.unload_unit(transporter_unit) # unloads first one + + """ + assert isinstance(transporter_unit, Unit) + assert isinstance(cargo_unit, (bool, Unit)) + + if not transporter_unit.passengers: + return + + if isinstance(cargo_unit, bool): + unload_unit_index = 0 + if isinstance(cargo_unit, Unit): + unload_unit_index = next((index for index, unit in enumerate(transporter_unit._proto.passengers) if unit.tag == cargo_unit.tag), 0) + await self._execute( + action=sc_pb.RequestAction( + actions=[ + sc_pb.Action( + action_raw=raw_pb.ActionRaw( + unit_command=raw_pb.ActionRawUnitCommand(ability_id=0, unit_tags=[transporter_unit.tag]) + ) + ), + sc_pb.Action( + action_ui=ui_pb.ActionUI( + cargo_panel=ui_pb.ActionCargoPanelUnload(unit_index=unload_unit_index) + ) + ), + ] + ) + ) + async def save_replay(self, path): logger.debug("Requesting replay from server") result = await self._execute(save_replay=sc_pb.RequestSaveReplay())