mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-12 20:18:08 +00:00
minor : improve C++ and Python style (#768)
* use some STL functions * use self.field than setattr, use pathlib.Path * recover some format * const some iter * Keep the original * 2 space
This commit is contained in:
@ -20,7 +20,7 @@ def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
|
||||
"""
|
||||
for k in state_dict:
|
||||
is_attention = all(substr in k for substr in ['attn', '.weight'])
|
||||
is_mlp = any([k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight']])
|
||||
is_mlp = any(k.endswith(s) for s in ['mlp.0.weight', 'mlp.2.weight'])
|
||||
|
||||
if (is_attention or is_mlp) and len(state_dict[k].shape) == 2:
|
||||
state_dict[k] = state_dict[k][:, :, None, None]
|
||||
@ -42,11 +42,10 @@ class LayerNormANE(LayerNormANEBase):
|
||||
class MultiHeadAttentionANE(MultiHeadAttention):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__(n_state, n_head)
|
||||
|
||||
setattr(self, 'query', nn.Conv2d(n_state, n_state, kernel_size=1))
|
||||
setattr(self, 'key', nn.Conv2d(n_state, n_state, kernel_size=1, bias=False))
|
||||
setattr(self, 'value', nn.Conv2d(n_state, n_state, kernel_size=1))
|
||||
setattr(self, 'out', nn.Conv2d(n_state, n_state, kernel_size=1))
|
||||
self.query = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||
self.key = nn.Conv2d(n_state, n_state, kernel_size=1, bias=False)
|
||||
self.value = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||
self.out = nn.Conv2d(n_state, n_state, kernel_size=1)
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
@ -104,30 +103,28 @@ class MultiHeadAttentionANE(MultiHeadAttention):
|
||||
class ResidualAttentionBlockANE(ResidualAttentionBlock):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__(n_state, n_head, cross_attention)
|
||||
|
||||
setattr(self, 'attn', MultiHeadAttentionANE(n_state, n_head))
|
||||
setattr(self, 'attn_ln', LayerNormANE(n_state))
|
||||
|
||||
setattr(self, 'cross_attn', MultiHeadAttentionANE(n_state, n_head) if cross_attention else None)
|
||||
setattr(self, 'cross_attn_ln', LayerNormANE(n_state) if cross_attention else None)
|
||||
self.attn = MultiHeadAttentionANE(n_state, n_head)
|
||||
self.attn_ln = LayerNormANE(n_state)
|
||||
self.cross_attn = MultiHeadAttentionANE(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn_ln = LayerNormANE(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
setattr(self, 'mlp', nn.Sequential(
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Conv2d(n_state, n_mlp, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(n_mlp, n_state, kernel_size=1)
|
||||
))
|
||||
setattr(self, 'mlp_ln', LayerNormANE(n_state))
|
||||
)
|
||||
self.mlp_ln = LayerNormANE(n_state)
|
||||
|
||||
|
||||
class AudioEncoderANE(AudioEncoder):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
|
||||
|
||||
setattr(self, 'blocks', nn.ModuleList(
|
||||
self.blocks = nn.ModuleList(
|
||||
[ResidualAttentionBlockANE(n_state, n_head) for _ in range(n_layer)]
|
||||
))
|
||||
setattr(self, 'ln_post', LayerNormANE(n_state))
|
||||
)
|
||||
self.ln_post = LayerNormANE(n_state)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
@ -168,10 +165,10 @@ class TextDecoderANE(TextDecoder):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
|
||||
|
||||
setattr(self, 'blocks', nn.ModuleList(
|
||||
self.blocks= nn.ModuleList(
|
||||
[ResidualAttentionBlockANE(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
))
|
||||
setattr(self, 'ln', LayerNormANE(n_state))
|
||||
)
|
||||
self.ln= LayerNormANE(n_state)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
@ -213,20 +210,20 @@ class WhisperANE(Whisper):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__(dims)
|
||||
|
||||
setattr(self, 'encoder', AudioEncoderANE(
|
||||
self.encoder = AudioEncoderANE(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
))
|
||||
setattr(self, 'decoder', TextDecoderANE(
|
||||
)
|
||||
self.decoder = TextDecoderANE(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
))
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
|
||||
|
||||
|
Reference in New Issue
Block a user