Skip to content

Commit

Permalink
tmp change - skip logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Jul 25, 2024
1 parent 5d9efb8 commit e32948b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tests/contrib/langchain/test_langchain_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import os
import sys

import langchain
import mock
import pytest

from ddtrace import patch
from ddtrace.contrib.langchain.patch import PATCH_LANGCHAIN_V0
from ddtrace.internal.utils.version import parse_version
from ddtrace.llmobs import LLMObs
from tests.contrib.langchain.utils import get_request_vcr
from tests.contrib.langchain.utils import long_input_text
Expand All @@ -17,8 +18,9 @@
from tests.subprocesstest import run_in_subprocess
from tests.utils import flaky

LANGCHAIN_VERSION = parse_version(langchain.__version__) < (0, 1, 0) or sys.version_info < (3, 10),

if PATCH_LANGCHAIN_V0:
if LANGCHAIN_VERSION < (0, 1, 0):
from langchain.schema import AIMessage
from langchain.schema import ChatMessage
from langchain.schema import HumanMessage
Expand Down Expand Up @@ -88,7 +90,7 @@ class BaseTestLLMObsLangchain:
def _invoke_llm(cls, llm, prompt, mock_tracer, cassette_name):
LLMObs.enable(ml_app=cls.ml_app, integrations_enabled=False, _tracer=mock_tracer)
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if PATCH_LANGCHAIN_V0:
if LANGCHAIN_VERSION < (0, 1, 0):
llm(prompt)
else:
llm.invoke(prompt)
Expand All @@ -103,7 +105,7 @@ def _invoke_chat(cls, chat_model, prompt, mock_tracer, cassette_name, role="user
messages = [HumanMessage(content=prompt)]
else:
messages = [ChatMessage(content=prompt, role="custom")]
if PATCH_LANGCHAIN_V0:
if LANGCHAIN_VERSION < (0, 1, 0):
chat_model(messages)
else:
chat_model.invoke(messages)
Expand All @@ -116,7 +118,7 @@ def _invoke_chain(cls, chain, prompt, mock_tracer, cassette_name, batch=False):
with get_request_vcr(subdirectory_name=cls.cassette_subdirectory_name).use_cassette(cassette_name):
if batch:
chain.batch(inputs=prompt)
elif PATCH_LANGCHAIN_V0:
elif LANGCHAIN_VERSION < (0, 1, 0):
chain.run(prompt)
else:
chain.invoke(prompt)
Expand Down Expand Up @@ -144,7 +146,7 @@ def _embed_documents(cls, embedding_model, documents, mock_tracer, cassette_name
return mock_tracer.pop_traces()[0]


@pytest.mark.skipif(not PATCH_LANGCHAIN_V0, reason="These tests are for langchain < 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION >= (0, 1, 0), reason="These tests are for langchain < 0.1.0")
class TestLLMObsLangchain(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain"

Expand Down Expand Up @@ -398,7 +400,7 @@ def test_llmobs_embedding_documents(self, langchain, mock_llmobs_span_writer, mo


@flaky(1735812000, reason="Community cassette tests are flaky")
@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
class TestLLMObsLangchainCommunity(BaseTestLLMObsLangchain):
cassette_subdirectory_name = "langchain_community"

Expand Down Expand Up @@ -648,7 +650,7 @@ def test_llmobs_embedding_documents(


@flaky(1735812000, reason="Community cassette tests are flaky")
@pytest.mark.skipif(PATCH_LANGCHAIN_V0, reason="These tests are for langchain >= 0.1.0")
@pytest.mark.skipif(LANGCHAIN_VERSION < (0, 1, 0), reason="These tests are for langchain >= 0.1.0")
class TestLangchainTraceStructureWithLlmIntegrations(SubprocessTestCase):
bedrock_env_config = dict(
AWS_ACCESS_KEY_ID="testing",
Expand Down

0 comments on commit e32948b

Please sign in to comment.