File size: 3,178 Bytes
2cc87ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union  # type: ignore[import-not-found]

from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor


class SpanEmbeddings(Embeddings, ABC):
    """Interface for models that embed text spans within documents.

    A span is addressed by ``start``/``end`` indices into its text. The
    whole-text helpers (:meth:`embed_documents`, :meth:`embed_query`) use the
    span ``(0, len(text))``, so indices appear to be character offsets with an
    exclusive end. Multi-span variants pass a list of starts and a list of
    ends instead of single values.

    Subclasses implement :meth:`embed_document_spans`, :meth:`embed_query_span`
    and :attr:`embedding_dim`. The whole-text and async methods are derived
    from those.

    NOTE(review): the whole-text methods return ``Optional`` embeddings, which
    is wider than the base ``Embeddings`` contract. Callers that go through
    the base interface should be prepared to see ``None`` entries.
    """

    @abstractmethod
    def embed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Embed one span (or one group of spans) per document.

        Args:
            texts: Documents to embed.
            starts: One start index per text, or one list of start indices
                per text (multi-span).
            ends: One end index per text, or one list of end indices per text
                (multi-span). Expected to align element-wise with ``starts``.

        Returns:
            One embedding per text, in input order. An entry may be ``None``;
            when that happens is implementation-defined.
        """

    @abstractmethod
    def embed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Embed a span (or group of spans) of a single query text.

        Args:
            text: Query text.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span). Expected to
                align element-wise with ``start``.

        Returns:
            The embedding, or ``None`` (implementation-defined).
        """

    def embed_documents(self, texts: list[str]) -> list[Optional[list[float]]]:
        """Embed each document in full.

        Delegates to :meth:`embed_document_spans` with the single span
        ``(0, len(text))`` for every text.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding (or ``None``) per text, in input order.
        """
        starts = [0] * len(texts)
        ends = [len(text) for text in texts]
        return self.embed_document_spans(texts, starts, ends)

    def embed_query(self, text: str) -> Optional[list[float]]:
        """Embed a query text in full.

        Delegates to :meth:`embed_query_span` with the span ``(0, len(text))``.

        Args:
            text: Query text.

        Returns:
            The embedding, or ``None``.
        """
        return self.embed_query_span(text, 0, len(text))

    async def aembed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Asynchronously embed one span (or group of spans) per document.

        The default runs the synchronous :meth:`embed_document_spans` through
        ``run_in_executor`` (executor argument ``None``). Subclasses with a
        native async client can override this.

        Args:
            texts: Documents to embed.
            starts: One start index per text, or one list of start indices
                per text (multi-span).
            ends: One end index per text, or one list of end indices per text
                (multi-span).

        Returns:
            One embedding (or ``None``) per text, in input order.
        """
        return await run_in_executor(None, self.embed_document_spans, texts, starts, ends)

    async def aembed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Asynchronously embed a span (or group of spans) of a query text.

        Async counterpart of :meth:`embed_query_span`. The default runs it
        through ``run_in_executor`` (executor argument ``None``).

        Args:
            text: Query text.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span).

        Returns:
            The embedding, or ``None``.
        """
        return await run_in_executor(None, self.embed_query_span, text, start, end)

    async def aembed_query_spans(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Backward-compatible alias of :meth:`aembed_query_span`.

        The original async method was named with a plural ``spans``, which
        broke the ``embed_x``/``aembed_x`` naming pairing. The name is kept so
        existing callers keep working. It delegates through ``self`` so that
        an override of :meth:`aembed_query_span` is honoured.

        Args:
            text: Query text.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span).

        Returns:
            The embedding, or ``None``.
        """
        return await self.aembed_query_span(text, start, end)

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Get the embedding dimension."""
        ...