1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@
./protos/pipeline_pb2.py
./protos/train_pb2.py
./protos/eval_pb2.py
mnist_data/*
53 changes: 45 additions & 8 deletions helpers.py
@@ -13,22 +13,29 @@ def _xentropy_loss_op(logit,label,name):
return:
mean cross entropy
"""
return tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logit,
labels=label),name=name)
with tf.name_scope("LOSS"):
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logit,
labels=label),name=name)
tf.summary.scalar("LOSS", loss)
return loss

def _eval_op(logits,label,name):
"""
Construct accuracy operation
Args:
logits -> output units of network, rank 2 [batch, num_classes]
label -> matching labels, rank 1 [batch]
name -> name scope of evaluation op
return:
mean accuracy op
"""
with tf.name_scope(name):
return tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits,label,1),tf.float32))
accuracy = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits,label,1),tf.float32))

tf.summary.scalar("ACCURACY", accuracy)
return accuracy
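
The change above moves the raw cross-entropy and accuracy ops under name scopes and attaches a `tf.summary.scalar` to each, so both values appear in TensorBoard once the summaries are merged and written. A minimal standalone sketch of that pattern (TensorFlow 1.x assumed; the placeholder shapes and the 10-class setup are illustrative, not taken from this repo):

```python
import numpy as np
import tensorflow as tf

# Illustrative stand-ins for the network output and labels (not from this repo).
logits = tf.placeholder(tf.float32, shape=[None, 10], name="logits")
labels = tf.placeholder(tf.int64, shape=[None], name="labels")

with tf.name_scope("LOSS"):
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels),
        name="xentropy")
    tf.summary.scalar("LOSS", loss)

with tf.name_scope("ACCURACY"):
    # in_top_k expects rank-2 predictions and rank-1 integer targets.
    accuracy = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32))
    tf.summary.scalar("ACCURACY", accuracy)

merged = tf.summary.merge_all()   # bundles both scalars into one summary protobuf

with tf.Session() as sess:
    feed = {logits: np.random.randn(4, 10).astype(np.float32),
            labels: np.array([3, 1, 4, 1], dtype=np.int64)}
    loss_val, acc_val, summary = sess.run([loss, accuracy, merged], feed_dict=feed)
    print(loss_val, acc_val)
```

The scalars only reach TensorBoard if the merged summary is fetched and handed to a `tf.summary.FileWriter`; the ops themselves just record values into the graph.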


class Helper:
@@ -94,8 +101,10 @@ def get_acc(predictions,
batched_tensors: output of get_inputs
Returns:
accs_dict : dict mapping labels to accuracy operations
scalar_updates : any updates for keeping stats
"""
accs_dict = {}
scalar_updates = []
labels = {label:
tensor for label, tensor in batched_tensors.items()
if label != "input"}
@@ -105,7 +114,26 @@ def get_acc(predictions,
name=label)
accs_dict[label] = acc

return accs_dict
acc_avg = tf.train.ExponentialMovingAverage(0.9, name='moving_avg')
acc_avg_op = acc_avg.apply([acc])

tf.summary.scalar('acc_for_' + str(label), acc_avg.average(acc))
scalar_updates.append(acc_avg_op)


# moving-average summary for the accuracy summed over all labels
sum_acc = 0.0
for k, v in accs_dict.items():
sum_acc += v

acc_avg = tf.train.ExponentialMovingAverage(0.9,name='moving_avg')
acc_avg_op = acc_avg.apply([sum_acc])

tf.summary.scalar('SUM_ACC',acc_avg.average(sum_acc))
scalar_updates.append(acc_avg_op)


return accs_dict, scalar_updates
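
`get_acc` now wraps each per-label accuracy (and their sum) in a `tf.train.ExponentialMovingAverage` and logs the smoothed shadow value instead of the raw op. The key contract is that `apply()` returns an update op that must actually be run for the average to move, which is why those ops are collected into `scalar_updates`. A small self-contained sketch of the apply/average pattern (TensorFlow 1.x assumed; the variable and feed values are illustrative):

```python
import tensorflow as tf

# Stand-in for one per-label accuracy value; in get_acc this is the _eval_op tensor.
raw_acc = tf.Variable(0.0, trainable=False, name="raw_acc")
new_value = tf.placeholder(tf.float32, shape=[])
set_acc = tf.assign(raw_acc, new_value)

acc_avg = tf.train.ExponentialMovingAverage(0.9, name="moving_avg")
acc_avg_op = acc_avg.apply([raw_acc])     # update op -> belongs in scalar_updates
smoothed_acc = acc_avg.average(raw_acc)   # shadow variable the summary reads
tf.summary.scalar("acc_smoothed", smoothed_acc)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for raw in [0.2, 0.4, 0.6, 0.8]:
        sess.run(set_acc, feed_dict={new_value: raw})
        sess.run(acc_avg_op)              # without this, the average never moves
        print(sess.run(smoothed_acc))     # trails the raw value (decay 0.9)
```

Because the summary reads the shadow variable, forgetting to run `scalar_updates` leaves the logged curve frozen at its initial value even though the raw accuracy changes every step.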

@staticmethod
def get_loss(predictions,
@@ -117,9 +145,10 @@ def get_loss(predictions,
batched_tensors: output of get_inputs
Returns:
loss: the combined loss for all labels
scalar_updates : any updates for keeping stats
scalar_updates : any updates for keeping stats (the last item in the list is the moving-average update for the total loss)
"""
losses = []
scalar_updates = []
labels = {label:
tensor for label, tensor in batched_tensors.items()
if label != "input"}
@@ -130,6 +159,14 @@ def get_loss(predictions,
tf.losses.add_loss(loss,
tf.GraphKeys.LOSSES)
losses.append(loss)

loss_avg = tf.train.ExponentialMovingAverage(0.9, name='moving_avg')
loss_avg_op = loss_avg.apply([loss])
tf.summary.scalar('loss_for_' + str(label), loss_avg.average(loss))
scalar_updates.append(loss_avg_op)


# moving-average summary for the total (summed) loss
loss = tf.reduce_sum(losses,
name = "total_loss")

@@ -141,6 +178,6 @@ def get_loss(predictions,
#log loss and shadow variables for avg loss
#tf.summary.scalar(loss.op.name+' (raw)',loss)
tf.summary.scalar(loss.op.name,loss_avg.average(loss))
scalar_updates = [loss_avg_op]
scalar_updates.append(loss_avg_op)

return loss, scalar_updates
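
`get_loss` therefore hands back both the combined loss and the list of moving-average update ops. A hedged sketch of how a training loop might consume that `(loss, scalar_updates)` pair (TensorFlow 1.x assumed; the toy dense model, optimizer, learning rate, random data, and log directory are stand-ins for this repo's real pipeline, not part of the diff):

```python
import numpy as np
import tensorflow as tf

# Toy model standing in for the repo's network; only the (loss, scalar_updates)
# contract from get_loss is being demonstrated here.
x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.int64, [None])
logits = tf.layers.dense(x, 3)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y),
    name="total_loss")
loss_avg = tf.train.ExponentialMovingAverage(0.9, name="moving_avg")
scalar_updates = [loss_avg.apply([loss])]          # same shape of output as get_loss
tf.summary.scalar(loss.op.name, loss_avg.average(loss))

train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter("/tmp/tf_logs")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(200):
        xs = np.random.randn(16, 4).astype(np.float32)
        ys = np.random.randint(0, 3, size=16).astype(np.int64)
        # Run the EMA updates together with the train op, otherwise the smoothed
        # summary curve never moves; then write the merged summary for TensorBoard.
        fetches = [train_op, merged] + scalar_updates
        results = sess.run(fetches, feed_dict={x: xs, y: ys})
        writer.add_summary(results[1], step)
    writer.close()
```

Running the update ops in the same `sess.run` call as the train op is one option; wrapping the train op in `tf.control_dependencies(scalar_updates)` would achieve the same effect without extra fetches.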