Deep Learning/XAI

Neural Additive Models: Interpretable Machine Learning with Neural Nets 논문 리뷰

논문 : https://arxiv.org/pdf/2004.13912.pdf


0. Abstract

  DNN(Deep neural network)은 인상적인 성능을 내지만 어떻게 결정을 내리는지 대게 불분명한 black-box predictor 이다. 이러한 성격 때문에 의료(healthcare) 분야와 같이 위험부담이 큰 영역에 대한 적용을 방해한다. 따라서 본 논문은 DNN의 표현성(expressivity)과 *일반화 가법모델(Generalized Additive Models)의 고유한 해석성(inherent intelligibility)을 결합한 NAM(Neural Addition Models)를 제안한다. NAM은 각각의 single input feature를 처리하는 신경망의 선형 결합(linear combination)을 학습한다. 실험을 통해 널리 사용되는 logistic regression과 shallow decision trees 보다 NAM이 더 정확하다는 것을 확인했다. 

 

* 일반화가법모형 (Generalized Additive Model, GAM)은 일반화선형모형(Generalized Linear Model, GLM)을 확장시킨 것으로, GLM에서 독립변수에 대해 적용되었던 선형 관계를 GAM에서는 비모수적 함수를 이용해 비선형적으로 표현할 수 있다.                             출처: https://drnq.tistory.com/245 [Dilettante Zen]

 

1. Introduction

  본 논문에서는 신경망 구조에 제한을 두어 NAM(Neural Addition Models)이라는 모델군을 생성하는데, 이는 표 형식 데이터에 적용될 때 예측의 정확도가 거의 손실되지는 않지만 본질적으로(inherently) 해석가능한(interpretable) 모델이다. 

binary classification을 위한 NAM 구조. 각각의 input 변수가 다른 신경망에 의해 처리된다.                           이것은 쉽게 해석할 수 있지만 매우 정확한 모델을 만든다.

  방법론적으로, NAMs는 GAMs(Generalized Additive Models)에 속한다.

  여기서 x는 k개의 특징이 있는 입력 값이고, y는 target variable이고 g(.)는 link function (e.g. logistic function),  f_i 는 다음과 같은 단일 변수 형상 함수(shape function)이다.  E[f_i] = 0. 

NAMs는 각각의 single input feature를 처리하는 네트워크의 linear combination을 학습한다. 각각의 f_i는 신경망에 의해 매개변수화(parametrized) 된다. 이런 네트워크는 backpropagation을 사용하여 공동으로 훈련되며 임의로 복잡한 형상 함수(shape function)를 학습할 수 있다. 예측(prediction)에 대한 feature의 영향이 다른 feature에 의존하지 않기 때문에 NAMs를 해석하는 것은 쉬우며 해당 shape function을 시각화하여 이해할 수 있다.

 

  NAMs의 장점은 다음과 같다.

2. Neural Additive Models 

 

2.1 Fitting Jagged Shape Functions 

  ReLUs 와 표준 초기화 함수(Kaiming initialization, Xavier initialization)로 과잉 파라미터화(over-parameterized)된 NNs는 NN 아키텍처가 충분히 expressive 되어 있음에도 불구하고 mini-batch gradient descent를 사용해서 학습시킬 때 이 toy dataset을 overfit 하려고 애쓴다.

  들쭉날쭉한 함수(jagged functions)를 학습할 때 그들의 global behavior에 영향을 주지 않고 ReLU 네트워크와 함께 대규모의 local한 변동(fluctuation)을 학습하는 것의 어려움은 smooth function을 배우려는 그들의 bias 때문이다.

  우리는 이러한 신경망 장애를 극복하기 위해 exp-centered (ExU) hidden units를 제안한다: 우리는 단순히 편향(bias)에 의해 이동된(shifted) inputs을 가진 로그 공간의 가중치를 배운다. 특히 스칼라 입력 x의 경우 활성화 함수를 사용하는 각 hidden unit는 다음과 같이 h(x)를 계산한다.

ExU는 neural nets의 표현력을 향상시켜주진 않지만 jumpy function을 fitting시키기 위한 학습성을 향상한다.

2.2 Regularization and Training

overfitting을 피하기 위해 다음과 같은 정규화 기술을 사용한다.

  •  Weight decay (가중치 감쇠) : Regularization 방법으로 특정 weight이 큰 값을 갖지 못하도록 하여 over-fitting을 방지하는 방법으로 사용된다.

[참고] 과적합(Overfitting)과 규제, 드랍아웃, 정규화

https://github.com/PolarJE/AI-summer-camp/blob/master/15.%20%EA%B3%BC%EC%A0%81%ED%95%A9(Overfitting)%EA%B3%BC%20%EA%B7%9C%EC%A0%9C%2C%20%EB%93%9C%EB%9E%8D%EC%95%84%EC%9B%83%2C%20%EC%A0%95%EA%B7%9C%ED%99%94.ipynb

 

 

2.3 Intelligibility and Modularity

NAMs의 intelligibility(명료성)은 부분적으로 쉽게 시각화 할 수 있기 때문이다. 각 feature는 신경망에 의해 parameterized된 학습된 형상 함수에 의해 독립적으로 처리되기 때문에 개별 형상 함수를 간단히 그래프로 표시하여 모델 전체를 볼 수 있습니다. input이 적은 data의 경우 모델의 동작에 대해 이해하기 쉬운 설명을 한 페이지에 완전히 시각화 할 수가 있다. 이런 형상 함수 플롯(shape function plot)은 단순한 설명이 아니라 NAMs이 예측(prediction)을 계산하는 방법에 대한 정확한 설명이다.

 

3. Experiments

NAMs가 다음과 같은 baseline에서 성능이 더 좋음을 보였다. 

 

3.1 Classification 

 

3.1.1 COMPAS : Risk Prediction in Criminal Justice