|
import dataclasses |
|
import json |
|
import warnings |
|
from dataclasses import dataclass, MISSING |
|
from functools import partial |
|
from typing import Optional, Any |
|
|
|
|
|
@partial(dataclass, frozen=True, kw_only=True) |
|
class JsonComparable: |
|
def to_json(self) -> str: |
|
return json.dumps(dataclasses.asdict(self)) |
|
|
|
def __eq__(self, other: "JsonComparable") -> bool: |
|
return self.to_json() == other.to_json() |
|
|
|
def __hash__(self) -> int: |
|
return hash(self.to_json()) |
|
|
|
def __lt__(self, other: "JsonComparable") -> bool: |
|
return self.to_json() < other.to_json() |
|
|
|
|
|
@partial(dataclass, frozen=True, kw_only=True)
class SubblockConfig(JsonComparable):
    """Options common to every sub-block (attention, FFN) of a block.

    ``no_op`` and ``replace_with_linear`` are mutually exclusive; violating
    this fails an assertion in ``__post_init__``.
    """

    # Flag: the sub-block is a no-op (interpretation lives outside this file).
    no_op: bool = False
    # Flag: the sub-block is replaced by a linear layer (interpretation lives
    # outside this file).
    replace_with_linear: bool = False
    # Optional list of names; NOTE(review): semantics defined by consumers
    # outside this file — confirm.
    sparsify: Optional[list[str]] = None

    def __post_init__(self):
        # At most one of the two replacement flags may be set.
        assert (not self.no_op) or (not self.replace_with_linear)

    def _force_setattr(self, name: str, value: Any) -> None:
        """Assign ``value`` to attribute ``name``, bypassing the frozen guard.

        Meant to be used only inside ``__post_init__``, where subclasses
        normalize their own fields after construction.
        """
        object.__setattr__(self, name, value)
|
|
|
|
|
@partial(dataclass, frozen=True, kw_only=True)
class AttentionConfig(SubblockConfig):
    """Configuration of an attention sub-block.

    When ``no_op`` or ``replace_with_linear`` is set, the attention-specific
    fields ``n_heads_in_group``, ``window_length`` and ``num_sink_tokens`` are
    cleared to ``None``. Otherwise ``n_heads_in_group`` is required. A config
    is "sink" attention when both ``window_length`` and ``num_sink_tokens``
    are set (see ``is_sink``).
    """

    n_heads_in_group: Optional[int] = None
    window_length: Optional[int] = None
    num_sink_tokens: Optional[int] = None
    # Only consulted for sink attention: whether the prefill pass still uses
    # window_length as a sliding window (see prefill_sliding_window).
    use_prefill_window_in_sink_attention: bool = False
    # Only validated for sink attention; incompatible with the flag above.
    unshifted_sink: bool = False

    def __post_init__(self):
        # SubblockConfig.__post_init__ already asserts that no_op and
        # replace_with_linear are not both set, so it is not re-checked here.
        super().__post_init__()

        if self.no_op or self.replace_with_linear:
            # These fields carry no meaning for a no-op / linear-replaced
            # block; clearing them makes otherwise-equivalent configs equal.
            for irrelevant_att in ("n_heads_in_group", "window_length", "num_sink_tokens"):
                self._force_setattr(irrelevant_att, None)
        else:
            assert self.n_heads_in_group is not None

        if self.is_sink:
            assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), (
                "Unshifted sink uses its own kind of explicit masking, not standard window. "
                "Set use_prefill_window_in_sink_attention to False."
            )
            assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), (
                "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"
            )

    @property
    def prefill_sliding_window(self) -> Optional[int]:
        """Sliding-window length to use during prefill, or ``None``.

        Returns ``window_length`` when it is set and either the config is not
        sink attention or ``use_prefill_window_in_sink_attention`` is True.
        """
        if self.window_length is not None:
            if not self.is_sink or self.use_prefill_window_in_sink_attention:
                return self.window_length
        return None

    @property
    def is_sliding(self) -> bool:
        """True when ``prefill_sliding_window`` yields a window length."""
        return self.prefill_sliding_window is not None

    @property
    def is_sink(self) -> bool:
        """True when both ``window_length`` and ``num_sink_tokens`` are set."""
        return (
            (self.window_length is not None)
            and
            (self.num_sink_tokens is not None)
        )
|
|
|
|
|
@partial(dataclass, frozen=True, kw_only=True)
class FFNConfig(SubblockConfig):
    """Configuration of a feed-forward (FFN) sub-block."""

    # Multiplier for the FFN (presumably its hidden width — NOTE(review):
    # confirm with consumers). Required unless no_op / replace_with_linear is
    # set, in which case it is cleared to None; otherwise stored rounded to 6
    # decimal places.
    ffn_mult: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        is_replaced = self.no_op or self.replace_with_linear
        if not is_replaced:
            assert self.ffn_mult is not None
            # Rounding keeps float noise from producing distinct configs.
            normalized_mult = round(self.ffn_mult, 6)
        else:
            normalized_mult = None
        self._force_setattr("ffn_mult", normalized_mult)
|
|
|
|
|
@partial(dataclass, frozen=True, kw_only=True)
class BlockConfig(JsonComparable):
    """One block: an attention sub-block plus an FFN sub-block.

    Both fields are required keyword arguments. Either may be passed as a
    plain ``dict``; ``__post_init__`` converts such dicts to the annotated
    config class.
    """

    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """Replace any dict-valued sub-block with an instance of its field type.

        Keys the target config class does not declare are dropped, and a
        ``warnings.warn`` lists them.
        """
        for block_field in dataclasses.fields(self):
            raw = getattr(self, block_field.name)
            if not isinstance(raw, dict):
                continue

            # block_field.type is the class object itself: this module does not
            # use postponed (string) annotations.
            config_cls = block_field.type
            known = {sub.name for sub in dataclasses.fields(config_cls)}
            unsupported_fields = [key for key in raw if key not in known]
            if unsupported_fields:
                warnings.warn(f"Removed unsupported fields {unsupported_fields} from {config_cls.__name__}")
                raw = {key: value for key, value in raw.items() if key in known}

            object.__setattr__(self, block_field.name, config_cls(**raw))
|
|