2024. 2. 20. 13:52ㆍ데이터 분석
오늘은 학습 모델의 평가 지표로 사용되는 Confusion Matrix를 통해 얻을 수 있는 평가 지표들에 대해서 그리고 Python에서 어떻게 생성할 수 있는지 그리고 어떻게 해석할 수 있는지에 대해서 정리하고자 합니다.
1. Confusion Matrix
Confusion Matrix, 혼동 행렬은 학습모델 중에서 지도학습 모델에 대한 이진분류 문제에서 많이 사용합니다. 혼동 행렬은 실제 값과 모델 예측 값을 기준으로 생성되며, 다음 네 가지 요소를 가지고 있습니다
- 참 양성(True Positive) : 실제 양성을 양성으로 올바르게 예측
- 거짓 양성(Flase Positive) : 실제 음성을 잘못하여 양성으로 예측
- 참 음성(True Negative) : 실제 음성을 음성으로 올바르게 예측
- 거짓 음성(False Negative) : 실제 양성을 잘못하여 음성으로 예측
다음 네 가지요소를 대문자만으로 TP,FP,TN,FN으로 나타내기도 합니다. 그리고 이러한 값들을 통해서 다음 지표들을 얻을 수 있습니다.
- 정확도(Accuracy) : 전체 예측 중 올바르게 예측된 비율. $$ \frac{TP+TN}{TP+TN+FP+FN}$$
- 정밀도(Prexision) : 양성으로 예측된 것들 중 실제로 양성인 것의 비율. $$\frac{TP}{TP+FP}$$
- 재현율(Recall) : 실제 양성 중 양성으로 올바르게 예측된 비율. $$\frac{TP}{TP+FN}$$
- F1-Score : 정밀도와 재현율의 조화 평균. $$2\times \frac{Precision\times Recall}{Precision+Recall}$$
보통의 상황에서는 정확도가 가장 중요한 지표처럼 보이지만 데이터의 불균형이 있을 경우 또는 특정 오류에 대해 민감한 문제가 발생하는 경우 정밀도 또는 재현율이 좋은 지표가 될 수 있습니다.
- 의료진단(재현율 중요) : 질병을 진단하는 경우, 실제로 질병을 가진 사람을 건강하다고 잘못 판단하는 것의 위험은 매우 높습니다. 이 경우, 질병이 있는 사람을 식별하는 것이 중요하기 때문에 재현율이 중요합니다.
- 스팸 이메일 필터링(정밀도 중요): 스팸 필터는 스팸이 아닌 이메일을 스팸으로 잘못 분류하는 것(거짓 양성)을 최소화해야 합니다. 중요한 이메일이 스팸으로 분류되는 문제는 크기 때문에 이 경우 정밀도가 중요합니다.
이번엔 계산하는 방법에 대한 예시를 들어보겠습니다. 예를 들어, 어떤 질병을 진단하는 모델이 100개의 샘플을 평가했고, 다음과 같은 결과를 얻었다고 가정하겠습니다.
TP = 30, FP = 20, TN = 40, FN = 10
- 정확도$$ \frac{TP+TN}{TP+TN+FP+FN} =\frac{30+40}{30+20+40+10}=0.7$$
- 정밀도$$ \frac{TP}{TP+FP} =\frac{30}{30+20}=0.60$$
- 재현율$$ \frac{TP}{TP+FN} =\frac{30}{30+10}=0.75$$
- F1 점수$$ 2\times \frac{Precision\times Recall}{Precision+Recall} =2\times \frac{0.60\times0.75}{0.60+0.75}\approx 0.67$$
2. Confusion Matrix in Python
만드는 것은 간단하게만 만들겠습니다. 데이터에 대해 모델 학습에 대한 부분은 우선 제외시켜놓고 하겠습니다. 우선 필요한 라이브러리들을 불러옵니다.
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
다음으로, 예시를 위해 실제 레이블과 예측 레이블을 정의합니다. 저는 예시를 위해 작성했으며, 실제로 모델을 학습 시킨다고 하면 실제 레이블은 test data의 실제 값과 예측 레이블은 test data의 모델 예측 값이 들어가게 됩니다.
y_true = [1, 0, 1, 1, 0, 1, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 1, 1, 0, 0, 0]
아래는 이제 혼동 행렬을 생성하고 각 평가 지표들에 대한 계산입니다.
# 혼동 행렬 생성
cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:")
print(cm)
# 정확도 계산
accuracy = accuracy_score(y_true, y_pred)
print("\nAccuracy:", accuracy)
# 재현율 계산
recall = recall_score(y_true, y_pred)
print("Recall:", recall)
# 정밀도 계산
precision = precision_score(y_true, y_pred)
print("Precision:", precision)
# F1 점수 계산
f1 = f1_score(y_true, y_pred)
print("F1 Score:", f1)
'데이터 분석' 카테고리의 다른 글
[데이터 분석]보건의료빅데이터분석 데이터 전처리 (0) | 2024.01.19 |
---|---|
[데이터 분석] Kaggle Data Report (0) | 2024.01.16 |