objective : 기존 MedCLIP-SAMv2가 사용하는 손실함수에 엔트로피 제약항을 더해 새로운 손실함수를 만들고 zero-shot segmentation 성능을 확인한다.
BiomedCLIP는 메모리 절약을 위해 ViT 생성하는 어텐션 맵을 사용 직후 버리는데, entropy loss를 계산하기 위해 어텐션 맵을 계속 갖고 있어야 한다. 이를 해결하기 위해 factory.py 내에서 ViT의 함수를 오버라이딩해 어텐션 맵을 별도로 저장하게 만든다. 이러한 몽키 패치 기법은 라이브러리 내부 동작을 조금 수정하고 싶을 때 자주 사용한다.
프로그래밍, 특히 Python과 같은 동적 언어에서 런타임(Runtime) 중에 소스 코드를 직접 수정하지 않고 클래스나 모듈의 기능을 동적으로 변경하거나 확장하는 기법
ex. forward 메서드 가로채기 (Feature Extraction) : 특정 모델의 중간 레이어 출력을 보고 싶은데, 모델이 최종 출력만 내뱉도록 설계된 경우 forward 함수를 바꿔치기할 수 있다.
import timm
import torch
# 1. 모델 로드
model = timm.create_model('resnet18', pretrained=True)
# 2. 기존 forward 함수 저장 (나중에 원상복구 하거나 내부에서 호출할 수도 있음)
original_forward = model.forward
# 3. 새로운 forward 함수 정의 (중간 feature를 print하도록 수정)
def new_forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
print(f"Intermediate shape: {x.shape}") # 중간 확인
# ... 나머지 로직은 생략하거나 original_forward 호출 불가 시 직접 구현
# 여기서는 단순 예시로 원래 로직을 흉내
x = self.global_pool(self.layer4(self.layer3(self.layer2(self.layer1(self.maxpool(x))))))
x = self.fc(x)
return x
# 4. 몽키 패치 적용 (런타임에 메서드 교체)
# 바운드 메서드(Bound Method)로 만들기 위해 types.MethodType을 쓰거나,
# 단순히 인스턴스 레벨에서 함수를 할당해도 Python에서는 동작함 (단, self 처리 주의)
import types
model.forward = types.MethodType(new_forward, model)
# 5. 실행
dummy_input = torch.randn(1, 3, 224, 224)
output = model(dummy_input) # "Intermediate shape: ..." 출력됨
ex. MedCLIP-SAM 적용
# Monkey patch timm Attention to capture attention maps for Entropy Loss
try:
import timm.models.vision_transformer
from timm.models.vision_transformer import Attention as TimmAttention
def patched_attention_forward(self, x, attn_mask=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
# Force manual attention to capture weights (disable fused_attn path)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
# Simple add, assuming mask is additive (e.g. -inf)
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
self.attn_map = attn # Capture attention map
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
TimmAttention.forward = patched_attention_forward
logging.info("Successfully patched timm.models.vision_transformer.Attention to capture attention maps.")
except ImportError:
logging.warning("Could not patch timm Attention. Entropy loss may not work for Timm models.")
except Exception as e:
logging.warning(f"Failed to patch timm Attention: {e}")
feature map
attention map