-def nb_rank_error(output, targets):
- output = output.reshape(-1, output.size(-1))
- targets = targets.reshape(-1, targets.size(-1))
- i = outputs.argmax(1)
- # out=input.gather out[i][j]=input[i][index[i][j]]
- # u[k]=targets[k][i[k]]
- return output[targets.argmax(1)]
-
-