import numpy as np
import torch
import torch.nn as nn

class OhemCELoss(nn.Module):

    def __init__(self, thresh, ignore_lb=255):
        super(OhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float))
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        n_min = labels[labels != self.ignore_lb].numel() // 16
        loss = self.criteria(logits, labels).view(-1)
        loss_hard = loss[loss > self.thresh]
        if loss_hard.numel() < n_min:
            loss_hard, _ = loss.topk(n_min)
        return torch.mean(loss_hard)

if __name__ == "__main__":
    # logit.shape:[2,13,320,640]
    logit = np.random.random((2, 13, 320, 640))
    target1 = np.random.randint(0, 13, size=(320, 640))
    target1 = target1[np.newaxis, :, :]
    target2 = np.random.randint(0, 13, size=(320, 640))
    target2 = target2[np.newaxis, :, :]
    # target.shape:[2,320,640]
    target = np.vstack([target1, target2])

    # numpy --> tensor
    logit = torch.tensor(logit)
    target = torch.tensor(target).long()

    # loss forword
    F = OhemCELoss(thresh = 0.7)
    loss = F.forward(logit, target)
    print(loss)

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐