Sparse Autoencoder as a Zero-Shot Classifier for Concept Erasing in Text-to-Image Diffusion Models 这篇论文提出了一种名为Interpret then Deactivate (ItD) 的框架,旨在文本到图像(T2I)扩散模型中实现精准、可扩展的概念擦除(即移除不想要的概念,如有害内容、特定名人等),同时不影响正常概念的生成。以下从思路、方法原理、数学公式推导三方面详细总结:
一、核心思路
现有概念擦除方法存在两大局限:1)微调模型参数会导致正常概念生成质量下降;2)集成定制模块泛化能力弱且需额外训练。为此,ItD框架通过“解释-停用”两步解决问题:
- 解释(Interpret):用稀疏自编码器(SAE)将概念分解为稀疏特征的线性组合,明确概念的特征构成;
- 停用(Deactivate):仅停用目标概念特有的特征(排除与正常概念共享的特征),实现精准擦除,同时保留正常概念的生成能力。
此外,SAE被复用为零样本分类器,可判断输入是否包含目标概念,仅在必要时应用擦除,进一步减少对正常概念的影响。
二、方法原理
1. 稀疏自编码器(SAE)的训练
SAE的作用是将文本编码器的语义信息(残差流输出)分解为稀疏特征的组合,为概念“解释”提供基础。
- 训练对象:文本编码器中transformer块的残差流输出(即token嵌入 (e_l^h)),其中 (l)为层索引,(h)为token索引)。
- 模型选择:采用K-稀疏自编码器(KSAE),强制每次重构仅保留K个最大激活特征,确保稀疏性。
- 核心目标:使SAE能将输入的token嵌入 (e)重构为稀疏特征的线性组合,即 (e approx sum_{rho=1}^{d_{hid}} z^rho f_rho)((z^rho)为特征激活值,(f_rho)为解码器矩阵的列向量,即特征向量)。
2. 特征选择:定位目标概念的“特有特征”
为避免擦除正常概念的特征,需筛选出目标概念特有的特征:
- 步骤1:收集目标概念的相关特征
对目标概念的每个token,通过SAE获取特征激活值,取每个特征在所有token中的最大激活值,筛选出前 (K_{sel})个高激活特征,构成目标特征集 (F_{tar})。 - 步骤2:排除与正常概念共享的特征
用正常概念集(retain set)的特征集 (F_{retain})与 (F_{tar})对比,移除两者共有的特征,得到目标概念特有特征集 (hat{F}_{tar} = F_{tar} setminus bigcup F_{retain})。 - 多概念擦除:对多个目标概念,取其特有特征集的并集 (F_{erase} = bigcup hat{F}_{tar})。
3. 概念擦除机制
通过调整目标特征的激活值,移除文本嵌入中目标概念的语义信息:
- 编码与调整:文本嵌入经SAE编码为特征激活 (s)后,将 (F_{erase})中的特征激活值缩放(如乘以小系数 (tau)),削弱其影响;
- 解码重构:调整后的激活经SAE解码器重构为新的文本嵌入,该嵌入不再包含目标概念的信息,从而阻止扩散模型生成相关图像。
4. 零样本分类器:选择性擦除
利用SAE的重构损失判断输入是否包含目标概念,仅在包含时应用擦除,减少对正常概念的干扰:
- 若输入文本嵌入 (e)包含目标概念,其经SAE重构后的误差 (|e - hat{e}|)较小(因目标特征被调整);
- 若为正常概念,重构误差较大。通过阈值 (tau)区分:(G(e) = 1)(含目标概念,应用擦除)若 (|e - hat{e}|^2 < tau),否则为0(不擦除)。
三、数学公式原理推导
1. SAE的编码器与解码器
-
编码器:将输入token嵌入 (e)转换为稀疏特征激活 (z)
[z = text{TopK}(W_{enc}(e - b_{pre})) ]其中,(W_{enc} in mathbb{R}^{d_{hid} times d_{in}})为编码器权重,(b_{pre})为偏置,(text{TopK})保留前K个最大激活值(其余置0),确保稀疏性。
-
解码器:将特征激活 (z)重构为嵌入 (hat{e})
[hat{e} = W_{dec} z + b_{pre} ]其中,(W_{dec} in mathbb{R}^{d_{in} times d_{hid}})为解码器权重,每列 (f_rho)为特征向量。
2. SAE的训练损失函数
目标是最小化重构误差并保证特征稀疏性,损失函数为:
- 第一项 (|e - hat{e}|_2^2):L2重构损失,确保输入与输出接近;
- 第二项 (alpha mathcal{L}_{aux}):辅助损失,防止“死特征”(即极少激活的特征)。(mathcal{L}_{aux})定义为使用前 (K_{aux})((K_{aux} > K))个特征的重构误差,(alpha)为权重系数。
3. 特征选择公式
-
目标概念特征集 (F_{tar}):
[F_{tar} = {rho mid s_C^rho in text{TopK}(s_C^1, ..., s_C^{d_{hid}})} ]其中 (s_C^rho = max(s_1^rho, ..., s_H^rho)),(s_h^rho)为第h个token的第(rho)个特征激活值,(H)为目标概念的token数。
-
特有特征集 (hat{F}_{tar}):
[hat{F}_{tar} = F_{tar} setminus bigcup_{C_r in C_{retain}} F_{C_r} ]其中 (C_{retain})为正常概念集,(F_{C_r})为正常概念的特征集。
当需要擦除多个目标概念时,总擦除特征集为各目标特有特征集的并集:
其中(mathcal{C}_{tar})为所有目标概念的集合,该式确保一次擦除多个概念且无需额外训练。
4. 概念擦除的激活调整
对特征激活 (s)进行调整,削弱目标特征的影响:
其中 (tau)为缩放系数(如 (tau < 1),削弱激活),调整后通过解码器重构为 (hat{e} = W_{dec} hat{s} + b_{pre})。
5. 零样本分类器的判断公式
基于重构损失判断是否包含目标概念:
其中 (tau)为阈值,(G(e)=1)时应用擦除,否则直接输出原嵌入。
四、总结
ItD框架通过SAE将概念分解为稀疏特征,结合对比特征选择和选择性擦除,实现了精准、可扩展的概念擦除。数学公式确保了SAE的稀疏性、特征的特异性及擦除的针对性,解决了现有方法对正常概念生成的干扰问题,同时通过零样本分类器进一步提升了鲁棒性。