Skip to content

Conversation

@rrevoid
Copy link
Contributor

@rrevoid rrevoid commented Jul 17, 2023

  • refactoring

@rrevoid rrevoid requested a review from SpirinEgor July 17, 2023 15:51

str_conditions = []
if self.condition is Conditions.YOUR:
str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you ignoring the mypy here? What error does it raise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It raises Argument 1 to "join" of "str" has incompatible type "Optional[List[str]]"; expected "Iterable[str]" [arg-type] error. Although the condition guarantees the existence of the my_persona attribute.

self.max_condition_len = max_condition_len or max_context_len

self.have_candidates = have_candidates
self.have_candidates = not "no_cands" in path
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this should be in the path? We use the same dataset for working with and without candidates

None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len
).view(batch_size, self.num_heads, src_len, trg_len)
float_mask = torch.where(mask, 0, float("-inf"))
return self.out(self.attention(query, key, value, attn_bias=float_mask).view(batch_size, trg_len, emb_dim))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split this line into multiple expressions, please

Comment on lines 20 to 32
class Conditions(Enum):
NONE = 0
YOUR = 1
PARTNERS = 2


def get_condition(path: str):
if "none" in path:
return Conditions.NONE
elif "self" in path:
return Conditions.YOUR
else:
return Conditions.PARTNERS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is too complicated. Since we are only collecting personas in the collect function, let's pass the required arguments directly to it.

def collate_fn(..., return_my_persona, return_partner_persona)

And also let's return dict, e.g.,

{
    "context": self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs)
    "candidates":  self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs)
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants