Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 41 additions & 20 deletions wechat.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
from io import open

import itchat
import tensorflow as tf
from itchat.content import TEXT, MAP, CARD, NOTE, SHARING

import itchat
from itchat.content import *
import config
import data_utils
from model import ChatBotModel
from chatbot import check_restore_parameters
from chatbot import construct_response
from chatbot import find_right_bucket
from chatbot import run_step
from model import ChatBotModel
from chatbot import construct_response

sess = tf.InteractiveSession()
_, enc_vocab = data_utils.load_vocab(os.path.join(config.DATA_PATH, "vocab.enc"))
Expand All @@ -30,17 +26,34 @@
print("Wechat server started. Say something. Max length is", max_length)


@itchat.msg_register([TEXT, MAP, CARD, NOTE, SHARING],
isFriendChat=True, isGroupChat=False, isMpChat=False)

# reply to friends, group or public account
@itchat.msg_register([TEXT,PICTURE,VIDEO], isFriendChat=True, isGroupChat=False, isMpChat=True)
def wechat(msg):
"""
Args:
msg: can be TEXT, MAP, CARD NOTE and SHARING of wechat.
Return:
wechat response.
"""
line = msg["Text"]
token_ids = data_utils.sentence2id(enc_vocab, line)
'''
reply according to different types of messages
:param msg: msg can be text and picture, for other types of message, reply '爸爸!'
:return: response to wechat
'''
if msg['Type'] == 'Text':
response = wechat_text(msg)
elif msg['Type'] == 'Picture':
response = wechat_pic(msg)
else:
response = '爸爸!'
query = msg['Text']
write_wechat_records(query, response)
return response


def wechat_text(msg):
'''
get the response to msg
:param msg: the type of the msg is 'Text'
:return: response to msg by using seq2seq model
'''
text = msg["Text"]
token_ids = data_utils.sentence2id(enc_vocab, text)
if len(token_ids) > max_length:
print("Max length I can handle is:", max_length)
# line = _get_user_input()
Expand All @@ -54,13 +67,21 @@ def wechat(msg):
decoder_inputs, decoder_masks,
bucket_id, True)
response = construct_response(output_logits, inv_dec_vocab)
write_wechat_records(line, response)

return response


def wechat_pic(msg):
'''
get the response to msg
:param msg: the type of the msg is 'Picture'
:return: response to picture
'''
return '爸爸!'

# write chat records to file
def write_wechat_records(query, response):
output_file = open(os.path.join(config.DATA_PATH, config.WECHAT_OUTPUT),
output_file = open(os.path.join(config.DATA_PATH, config.WECHAT_FILE),
"a+", encoding="utf-8")
output_file.write("HUMAN ++++ " + query + "\n")
output_file.write("BOT ++++ " + response + "\n")
Expand Down