OHEM loss 源代码
ohem loss
·
import numpy as np
import torch
import torch.nn as nn
class OhemCELoss(nn.Module):
def __init__(self, thresh, ignore_lb=255):
super(OhemCELoss, self).__init__()
self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float))
self.ignore_lb = ignore_lb
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
def forward(self, logits, labels):
n_min = labels[labels != self.ignore_lb].numel() // 16
loss = self.criteria(logits, labels).view(-1)
loss_hard = loss[loss > self.thresh]
if loss_hard.numel() < n_min:
loss_hard, _ = loss.topk(n_min)
return torch.mean(loss_hard)
if __name__ == "__main__":
# logit.shape:[2,13,320,640]
logit = np.random.random((2, 13, 320, 640))
target1 = np.random.randint(0, 13, size=(320, 640))
target1 = target1[np.newaxis, :, :]
target2 = np.random.randint(0, 13, size=(320, 640))
target2 = target2[np.newaxis, :, :]
# target.shape:[2,320,640]
target = np.vstack([target1, target2])
# numpy --> tensor
logit = torch.tensor(logit)
target = torch.tensor(target).long()
# loss forword
F = OhemCELoss(thresh = 0.7)
loss = F.forward(logit, target)
print(loss)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献2条内容
所有评论(0)