Skip to content

Commit

Permalink
fix: score for jupyter (#1411)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 authored Oct 2, 2024
1 parent 167641e commit 4aa4315
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions src/ragas/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ class Metric(ABC):

@property
@abstractmethod
def name(self) -> str: ...
def name(self) -> str:
...

@property
def required_columns(self) -> t.Dict[str, t.Set[str]]:
Expand Down Expand Up @@ -145,7 +146,8 @@ async def ascore(
return score

@abstractmethod
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ...
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
...


@dataclass
Expand Down Expand Up @@ -193,6 +195,15 @@ def single_turn_score(
self.name, inputs=sample.model_dump(), callbacks=callbacks
)
try:
if is_event_loop_running():
try:
import nest_asyncio

nest_asyncio.apply()
except ImportError:
raise ImportError(
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
)
loop = asyncio.get_event_loop()
score = loop.run_until_complete(
self._single_turn_ascore(sample=sample, callbacks=group_cm)
Expand Down Expand Up @@ -234,7 +245,8 @@ async def _single_turn_ascore(
self,
sample: SingleTurnSample,
callbacks: Callbacks,
) -> float: ...
) -> float:
...


class MultiTurnMetric(Metric):
Expand All @@ -248,6 +260,15 @@ def multi_turn_score(
self.name, inputs=sample.model_dump(), callbacks=callbacks
)
try:
if is_event_loop_running():
try:
import nest_asyncio

nest_asyncio.apply()
except ImportError:
raise ImportError(
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
)
loop = asyncio.get_event_loop()
score = loop.run_until_complete(
self._multi_turn_ascore(sample=sample, callbacks=group_cm)
Expand Down Expand Up @@ -290,7 +311,8 @@ async def _multi_turn_ascore(
self,
sample: MultiTurnSample,
callbacks: Callbacks,
) -> float: ...
) -> float:
...


class Ensember:
Expand Down

0 comments on commit 4aa4315

Please sign in to comment.