diff --git a/wechat.py b/wechat.py index b9d81e8..9c4d360 100644 --- a/wechat.py +++ b/wechat.py @@ -1,21 +1,17 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals - import os -from io import open - -import itchat import tensorflow as tf -from itchat.content import TEXT, MAP, CARD, NOTE, SHARING - +import itchat +from itchat.content import * import config import data_utils +from model import ChatBotModel from chatbot import check_restore_parameters -from chatbot import construct_response from chatbot import find_right_bucket from chatbot import run_step -from model import ChatBotModel +from chatbot import construct_response sess = tf.InteractiveSession() _, enc_vocab = data_utils.load_vocab(os.path.join(config.DATA_PATH, "vocab.enc")) @@ -30,17 +26,34 @@ print("Wechat server started. Say something. Max length is", max_length) -@itchat.msg_register([TEXT, MAP, CARD, NOTE, SHARING], - isFriendChat=True, isGroupChat=False, isMpChat=False) + +# reply to friends, group or public account +@itchat.msg_register([TEXT,PICTURE,VIDEO], isFriendChat=True, isGroupChat=False, isMpChat=True) def wechat(msg): - """ - Args: - msg: can be TEXT, MAP, CARD NOTE and SHARING of wechat. - Return: - wechat response. - """ - line = msg["Text"] - token_ids = data_utils.sentence2id(enc_vocab, line) + ''' + reply according to different types of messages + :param msg: msg can be text and picture, for other types of message, reply '爸爸!' + :return: response to wechat + ''' + if msg['Type'] == 'Text': + response = wechat_text(msg) + elif msg['Type'] == 'Picture': + response = wechat_pic(msg) + else: + response = '爸爸!' + query = msg['Text'] + write_wechat_records(query, response) + return response + + +def wechat_text(msg): + ''' + get the response to msg + :param msg: the type of the msg is 'Text' + :return: response to msg by using seq2seq model + ''' + text = msg["Text"] + token_ids = data_utils.sentence2id(enc_vocab, text) if len(token_ids) > max_length: print("Max length I can handle is:", max_length) # line = _get_user_input() @@ -54,13 +67,21 @@ def wechat(msg): decoder_inputs, decoder_masks, bucket_id, True) response = construct_response(output_logits, inv_dec_vocab) - write_wechat_records(line, response) + return response +def wechat_pic(msg): + ''' + get the response to msg + :param msg: the type of the msg is 'Picture' + :return: response to picture + ''' + return '爸爸!' + # write chat records to file def write_wechat_records(query, response): - output_file = open(os.path.join(config.DATA_PATH, config.WECHAT_OUTPUT), + output_file = open(os.path.join(config.DATA_PATH, config.WECHAT_FILE), "a+", encoding="utf-8") output_file.write("HUMAN ++++ " + query + "\n") output_file.write("BOT ++++ " + response + "\n")