-
Notifications
You must be signed in to change notification settings - Fork 0
Add adapter model for conditioning #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
rrevoid
commented
Jul 17, 2023
- refactoring
src/data/convai2_dataset.py
Outdated
|
|
||
| str_conditions = [] | ||
| if self.condition is Conditions.YOUR: | ||
| str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you ignoring the mypy here? What error does it raise?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It raises Argument 1 to "join" of "str" has incompatible type "Optional[List[str]]"; expected "Iterable[str]" [arg-type] error. Although the condition guarantees the existence of the my_persona attribute.
| self.max_condition_len = max_condition_len or max_context_len | ||
|
|
||
| self.have_candidates = have_candidates | ||
| self.have_candidates = not "no_cands" in path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this should be in the path? We use the same dataset for working with and without candidates
src/conditioning/adapter.py
Outdated
| None, encoder_attention_mask.bool(), batch_size, src_len, self.num_heads, trg_len | ||
| ).view(batch_size, self.num_heads, src_len, trg_len) | ||
| float_mask = torch.where(mask, 0, float("-inf")) | ||
| return self.out(self.attention(query, key, value, attn_bias=float_mask).view(batch_size, trg_len, emb_dim)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Split this line into multiple expressions, please
src/data/convai2_dataset.py
Outdated
| class Conditions(Enum): | ||
| NONE = 0 | ||
| YOUR = 1 | ||
| PARTNERS = 2 | ||
|
|
||
|
|
||
| def get_condition(path: str): | ||
| if "none" in path: | ||
| return Conditions.NONE | ||
| elif "self" in path: | ||
| return Conditions.YOUR | ||
| else: | ||
| return Conditions.PARTNERS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic is too complicated. Since we are only collecting personas in the collect function, let's pass the required arguments directly to it.
def collate_fn(..., return_my_persona, return_partner_persona)And also let's return dict, e.g.,
{
"context": self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs)
"candidates": self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs)
}