Pytorch和Tensorflow中的交叉熵损失函数

  • 导入支持

  • 生成测试数据

    labels相当于真实的分类数据,其中onehot_labels是对类别号的标记方式进行的onehot处理;logits是网络生成的预测数据

  • 在TensorFlow中

    其中用到了tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tflabels, logits=tflogits)和tf.nn.softmax_cross_entropy_with_logits(labels=tflabels_oh, logits=tflogits)两个方法,两者的主要区别是前者传入的labels可以是直接的数字类别标记,而后者的传入onehot化之后的labels

  • 在Pytorch中

    其中用到了两种方式实现,两者没有差别,后者是前者的内部实现方式;其中reduce的作用是是否对结果进一步处理,不过不设定,会默认输出当前结果的平均值(就只有一个值了)

  • 结果分析

    print结果如下

    可以看到,结果都是一样的

  • 说明

    TensorFlow版本1.14

  • 参考文献

    【TensorFlow】关于tf.nn.sparse_softmax_cross_entropy_with_logits()

    pytorch笔记:03)softmax和log_softmax,以及CrossEntropyLoss

You may also like...

发表评论

电子邮件地址不会被公开。 必填项已用*标注