Skip to content

Commit

Permalink
Add Index class
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 20, 2023
1 parent 338eb56 commit 18cecff
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions outlines/index/index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
from typing import Callable, NamedTuple, NewType
from dataclasses import dataclass
from typing import NewType, Protocol, Union

import torch

State = NewType("State", int)


class Index(NamedTuple):
next_instruction: Callable[[State], torch.Tensor]
next_state: Callable[[State, torch.Tensor], State]
is_final: Callable[[State], bool]
@dataclass(frozen=True)
class GenerateInstruction:
logits_mask: str
temperature: float
top_k: int
top_p: int


@dataclass(frozen=True)
class FillInstruction:
token_ids: int


FSMInstruction = Union[GenerateInstruction, FillInstruction]


class Index(Protocol):
def next_instruction(self, state: State) -> FSMInstruction:
...

def next_state(self, state: State, token_id: torch.Tensor) -> State:
...

def is_final_state(self, state: State) -> bool:
...

0 comments on commit 18cecff

Please sign in to comment.