MNIST 데이터셋 검증 - 오차행렬

1. 오차 행렬

분류기의 성능을 평가하는 더 좋은 방법은 오차 행렬(Confusion Matrix)을 조사하는 것입니다. 클래스 A의 샘플이 클래스 B로 분류된 횟수를 세는 것입니다.

오차 행렬을 만들기 위해서는 실제 타깃과 비교할 수 있도록 먼저 예측값을 만들어야 합니다. 테스트 세트로 예측을 만들 수 있지만 여기서 사용해서는 안됩니다. 대신 cross_val_predict() 함수를 사용할 수 있습니다.

Code

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
print(confusion_matrix(y_train_5, y_train_pred))

Output

54,125개를 '5 아님'으로 정확하게 분류하였고(TN), 나머지 454개는 '5'라고 잘못 분류하였습니다(FP).
1,567개를 '5 아님'으로 잘못 분류하였고(FN), 나머지 3,854개를 정확히 '5'라고 분류하였습니다(TP). 

완벽한 분류기라면 진짜 양성진짜 음성만 가지고 있을 것이므로 오차 행렬의 주대각선만 0이 아닌 값이 됩니다.



오차 행렬이 많은 정보를 제공해주지만, 가끔 더 요약된 지표가 필요할 때도 있습니다.
살펴볼만한 것 하나는 양성 예측의 정확도입니다. 이를 분류기의 정밀도(Precision)라고 합니다. 


< Exp. 1. 정밀도 공식 >

TP는 진짜 양성의 수이고, FP는 거짓 양성의 수입니다.

확실한 양성 샘플 하나만 예측하면 간단히 완벽한 정밀도를 얻을 수 있지만, 이는 분류기가 다른 모든 양성 샘플을 무시하기 때문에 그리 유용하지 않습니다. 정밀도는 재현율(Recall) 이라는 또 다른 지표와 같이 사용하는 것이 일반적입니다. 재현율은 분류기가 정확하게 감지한 양성 샘플의 비율로, 민감도(Sensitivity), TPR(True Positive Rate) 이라고도 합니다.



2. 정밀도와 재현율

사이킷런에서는 정밀도와 재현율을 포함하여 분류기의 지표를 계산하는 여러 함수를 제공하고 있습니다.

Code

from sklearn.metrics import precision_score, recall_score
print(precision_score(y_train_5, y_train_pred))
print(recall_score(y_train_5, y_train_pred))

Output

5로 판별된 이미지 중 74%만 정확합니다. 전체 숫자에서는 85%만 감지하였습니다.



정밀도와 재현율을 F₁ 점수라고 하는 하나의 숫자로 만들면 편리할 때가 많습니다. 특히 두 분류기를 비교할 때 유용합니다.
F₁ 점수는 정밀도와 재현율의 조화 평균(Harmonic Mean)입니다.


< Exp. 2. F₁ 점수 공식 >


F₁ 점수를 계산하기 위해서는 f1_score() 함수를 호출하면 됩니다.

Code

from sklearn.metrics import f1_score
print(f1_score(y_train_5, y_train_pred))

Output

정밀도와 재현율이 비슷한 분류기에서는 F₁ 점수가 높습니다. 하지만 이게 항상 바람직한 것은 아닙니다. 상황에 따라 정밀도가 중요할 수도 있고 재현율이 중요할 수도 있습니다.
아쉽지만 정밀도와 재현율을 둘다 얻을 수는 없습니다. 정밀도를 올리면 재현율이 줄고 그 반대도 마찬가지 입니다. 이를 정밀도 / 재현율 트레이드오프라고 합니다.



References

  • 오렐리앙 제롱, '핸즈온 머신러닝', 한빛미디어, 2018


+ Recent posts