from __future__ import annotations
import json
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Union,
)
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
if TYPE_CHECKING:
from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
from xinference.model.llm.core import LlamaCppGenerateConfig
class Xinference(LLM):
"""`Xinference` large-scale model inference service.
To use, you should have the xinference library installed:
.. code-block:: bash
pip install "xinference[all]"
If you only need to connect to an existing Xinference service, the lighter xinference_client package is sufficient:
.. code-block:: bash
pip install xinference_client
Check out: https://github.com/xorbitsai/inference
To run Xinference, you need to start a Xinference supervisor on one server and Xinference workers on the other servers.
Example:
To start a local instance of Xinference, run
.. code-block:: bash
$ xinference
You can also deploy Xinference in a distributed cluster. Here are the steps:
Starting the supervisor:
.. code-block:: bash
$ xinference-supervisor
Starting the worker:
.. code-block:: bash
$ xinference-worker
Then, launch a model using the command-line interface (CLI).
Example:
.. code-block:: bash
$ xinference launch -n orca -s 3 -q q4_0
It will return a model UID. Then, you can use Xinference with LangChain.
Example:
.. code-block:: python
from langchain_community.llms import Xinference
llm = Xinference(
server_url="http://0.0.0.0:9997",
model_uid={model_uid},  # replace model_uid with the model UID returned when launching the model
)
llm.invoke(
prompt="Q: where can we visit in the capital of France? A:",
generate_config={"max_tokens": 1024, "stream": True},
)
Example:
.. code-block:: python
from langchain_community.llms import Xinference
from langchain.prompts import PromptTemplate
llm = Xinference(
server_url="http://0.0.0.0:9997",
model_uid={model_uid},  # replace model_uid with the model UID returned when launching the model
stream=True
)
prompt = PromptTemplate(
input_variables=["country"],
template="Q: where can we visit in the capital of {country}? A:"
)
chain = prompt | llm
chain.stream(input={'country': 'France'})
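Example with an authenticated cluster (a minimal sketch: the ``api_key`` value is a placeholder and is simply forwarded to the underlying RESTfulClient):
.. code-block:: python
from langchain_community.llms import Xinference
llm = Xinference(
server_url="http://0.0.0.0:9997",
model_uid={model_uid},  # replace with the model UID returned when launching the model
api_key="YOUR_API_KEY",  # placeholder; only needed when the cluster enforces authentication
)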
To view all the supported built-in models, run:
.. code-block:: bash
$ xinference list --all
""" # noqa: E501
client: Optional[Any] = None
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
model_kwargs: Dict[str, Any]
"""Keyword arguments to be passed to xinference.LLM"""
def __init__(
self,
server_url: Optional[str] = None,
model_uid: Optional[str] = None,
api_key: Optional[str] = None,
**model_kwargs: Any,
):
try:
from xinference.client import RESTfulClient
except ImportError:
try:
from xinference_client import RESTfulClient
except ImportError as e:
raise ImportError(
"Could not import RESTfulClient from xinference. Please install it"
" with `pip install xinference` or `pip install xinference_client`."
) from e
model_kwargs = model_kwargs or {}
super().__init__(
**{ # type: ignore[arg-type]
"server_url": server_url,
"model_uid": model_uid,
"model_kwargs": model_kwargs,
}
)
if self.server_url is None:
raise ValueError("Please provide server URL")
if self.model_uid is None:
raise ValueError("Please provide the model UID")
self._headers: Dict[str, str] = {}
self._cluster_authed = False
self._check_cluster_authenticated()
if api_key is not None and self._cluster_authed:
self._headers["Authorization"] = f"Bearer {api_key}"
self.client = RESTfulClient(server_url, api_key)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "xinference"
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{"server_url": self.server_url},
**{"model_uid": self.model_uid},
**{"model_kwargs": self.model_kwargs},
}
def _check_cluster_authenticated(self) -> None:
url = f"{self.server_url}/v1/cluster/auth"
response = requests.get(url)
if response.status_code == 404:
self._cluster_authed = False
else:
if response.status_code != 200:
raise RuntimeError(
f"Failed to get cluster information, "
f"detail: {response.json()['detail']}"
)
response_data = response.json()
self._cluster_authed = bool(response_data["auth"])
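# Illustrative only: judging from the parsing above, an auth-enabled cluster is
# expected to answer GET {server_url}/v1/cluster/auth with a JSON body such as
# {"auth": true}, while a 404 means the endpoint (and thus authentication) is
# not available. The exact payload may vary between Xinference versions.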
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the xinference model and return the output.
Args:
prompt: The prompt to use for generation.
stop: Optional list of stop words to use when generating.
generate_config: Optional dictionary for the configuration used for
generation.
Returns:
The string generated by the model.
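Example (an illustrative sketch; keys in ``generate_config`` are forwarded to the Xinference generate call, so the accepted keys depend on the launched model):
.. code-block:: python
llm.invoke(
"Q: What can we visit in the capital of France? A:",
stop=["Q:"],
generate_config={"max_tokens": 256},
)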
"""
if self.client is None:
raise ValueError("Client is not initialized!")
model = self.client.get_model(self.model_uid)
generate_config: "LlamaCppGenerateConfig" = kwargs.get("generate_config", {})
generate_config = {**self.model_kwargs, **generate_config}
if stop:
generate_config["stop"] = stop
if generate_config and generate_config.get("stream"):
combined_text_output = ""
for token in self._stream_generate(
model=model,
prompt=prompt,
run_manager=run_manager,
generate_config=generate_config,
):
combined_text_output += token
return combined_text_output
else:
completion = model.generate(prompt=prompt, generate_config=generate_config)
return completion["choices"][0]["text"]
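# Illustrative only: the non-streaming branch above assumes a completion payload
# shaped like {"choices": [{"text": "...", "finish_reason": "stop"}], ...}, which
# is why the generated text is read from completion["choices"][0]["text"].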
def _stream_generate(
self,
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle"],
prompt: str,
run_manager: Optional[CallbackManagerForLLMRun] = None,
generate_config: Optional["LlamaCppGenerateConfig"] = None,
) -> Generator[str, None, None]:
"""
Args:
prompt: The prompt to use for generation.
model: The model used for generation.
generate_config: Optional dictionary for the configuration used for
generation.
Yields:
A string token.
"""
streaming_response = model.generate(
prompt=prompt, generate_config=generate_config
)
for chunk in streaming_response:
if isinstance(chunk, dict):
choices = chunk.get("choices", [])
if choices:
choice = choices[0]
if isinstance(choice, dict):
token = choice.get("text", "")
log_probs = choice.get("logprobs")
if run_manager:
run_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=log_probs
)
yield token
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
generate_config = kwargs.get("generate_config", {})
generate_config = {**self.model_kwargs, **generate_config}
if stop:
generate_config["stop"] = stop
for stream_resp in self._create_generate_stream(prompt, generate_config):
if stream_resp:
chunk = self._stream_response_to_generation_chunk(stream_resp)
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
def _create_generate_stream(
self, prompt: str, generate_config: Optional[Dict[str, Any]] = None
) -> Iterator[Any]:
if self.client is None:
raise ValueError("Client is not initialized!")
model = self.client.get_model(self.model_uid)
yield from model.generate(prompt=prompt, generate_config=generate_config)
@staticmethod
def _stream_response_to_generation_chunk(
stream_response: Any,
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
token = ""
if isinstance(stream_response, dict):
choices = stream_response.get("choices", [])
if choices:
choice = choices[0]
if isinstance(choice, dict):
token = choice.get("text", "")
return GenerationChunk(
text=token,
generation_info=dict(
finish_reason=choice.get("finish_reason", None),
logprobs=choice.get("logprobs", None),
),
)
else:
raise TypeError("choice type error!")
else:
return GenerationChunk(text=token)
else:
raise TypeError("stream_response type error!")
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
generate_config = kwargs.get("generate_config", {})
generate_config = {**self.model_kwargs, **generate_config}
if stop:
generate_config["stop"] = stop
async for stream_resp in self._acreate_generate_stream(prompt, generate_config):
if stream_resp:
chunk = self._stream_response_to_generation_chunk(stream_resp)
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
)
yield chunk
async def _acreate_generate_stream(
self, prompt: str, generate_config: Optional[Dict[str, Any]] = None
) -> AsyncIterator[Any]:
request_body: Dict[str, Any] = {"model": self.model_uid, "prompt": prompt}
if generate_config is not None:
for key, value in generate_config.items():
request_body[key] = value
stream = bool(generate_config and generate_config.get("stream"))
async with aiohttp.ClientSession() as session:
async with session.post(
url=f"{self.server_url}/v1/completions",
json=request_body,
headers=self._headers,
) as response:
if response.status != 200:
if response.status == 404:
raise FileNotFoundError(
"astream call failed with status code 404."
)
else:
optional_detail = await response.text()
raise ValueError(
f"astream call failed with status code {response.status}."
f" Details: {optional_detail}"
)
async for line in response.content:
if not stream:
yield json.loads(line)
else:
json_str = line.decode("utf-8")
if line.startswith(b"data:"):
json_str = json_str[len(b"data:") :].strip()
if not json_str:
continue
yield json.loads(json_str)
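# Minimal async usage sketch (an assumption-labelled example, not part of the
# library API: it presumes a reachable Xinference endpoint and an already
# launched model; both values below are placeholders):
#
#     import asyncio
#     from langchain_community.llms import Xinference
#
#     async def main() -> None:
#         llm = Xinference(server_url="http://0.0.0.0:9997", model_uid="MODEL_UID")
#         async for token in llm.astream("Q: What can we visit in Paris? A:"):
#             print(token, end="", flush=True)
#
#     asyncio.run(main())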