Remove the "ShieldType" concept (#430)
# What does this PR do?

This PR kills the notion of "ShieldType". The impetus for this is the
realization:

> Why is the keyword llama-guard appearing so many times everywhere,
> sometimes with hyphens, sometimes with underscores?

We now have a notion of "provider specific resource identifiers" and of
"user specific aliases" for them, and this already works for models
("Llama3.1-8B-Instruct" <> "fireworks/llama-3pv1-..."), so we can
follow the same rules for Shields.

Each Safety provider can therefore define its own set of registered
identifiers. Bedrock already does this correctly. We just
generalize it to Llama Guard, Prompt Guard, etc.

For Llama Guard, we further simplify by just adopting the underlying
model name itself as the identifier! No confusion necessary.
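
The aliasing rule described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `ShieldRecord`, `ShieldTable`, and the provider id `"example-provider"` are hypothetical names, and defaulting the provider-side identifier to the user-facing one is an assumption made for the sketch.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ShieldRecord:
    # Hypothetical record; field names follow the Shield schema in this diff.
    identifier: str            # user-facing alias, e.g. "Llama-Guard-3-1B"
    provider_resource_id: str  # identifier the provider itself understands
    provider_id: str
    params: Dict[str, Any] = field(default_factory=dict)


class ShieldTable:
    """Hypothetical in-memory table of registered shields."""

    def __init__(self) -> None:
        self._shields: Dict[str, ShieldRecord] = {}

    def register_shield(
        self,
        shield_id: str,
        provider_id: str,
        provider_shield_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> ShieldRecord:
        # Assumption: with no explicit provider-side id, reuse the alias.
        record = ShieldRecord(
            identifier=shield_id,
            provider_resource_id=provider_shield_id or shield_id,
            provider_id=provider_id,
            params=params or {},
        )
        self._shields[shield_id] = record
        return record

    def get_shield(self, shield_id: str) -> Optional[ShieldRecord]:
        # Lookup is by shield_id only; no shield_type is involved.
        return self._shields.get(shield_id)


table = ShieldTable()
guard = table.register_shield("Llama-Guard-3-1B", provider_id="example-provider")
```

With this shape, callers only ever pass a `shield_id`, and the provider decides what the corresponding `provider_resource_id` means.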

While doing this, I noticed a bug in our DistributionRegistry where we
weren't scoping identifiers by type. Fixed.
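
To illustrate why scoping by type matters once a shield can share its identifier with a model, here is a minimal sketch of a type-scoped lookup table. `TypeScopedRegistry` and its key layout are hypothetical and are not taken from the actual DistributionRegistry implementation.

```python
from typing import Any, Dict, Optional, Tuple


class TypeScopedRegistry:
    """Hypothetical registry that keys entries by (type, identifier)."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[str, str], Any] = {}

    def register(self, type_: str, identifier: str, obj: Any) -> None:
        # Including the type in the key keeps a model and a shield with
        # the same identifier from overwriting each other.
        self._entries[(type_, identifier)] = obj

    def get(self, type_: str, identifier: str) -> Optional[Any]:
        return self._entries.get((type_, identifier))


registry = TypeScopedRegistry()
registry.register("model", "Llama-Guard-3-1B", {"kind": "model"})
registry.register("shield", "Llama-Guard-3-1B", {"kind": "shield"})
```

If the key were the identifier alone, the second `register` call would silently replace the first entry.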

## Feature/Issue validation/testing/test plan

Ran (inference, safety, memory, agents) tests with ollama and fireworks
providers.
ashwinb authored Nov 12, 2024
1 parent 09269e2 commit 983d6ce
Showing 26 changed files with 147 additions and 206 deletions.
3 changes: 2 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant

- [ ] Addresses issue (#issue)

## Feature/Issue validation/testing/test plan

## Test Plan

Please describe:
- tests you ran to verify your changes with result summaries.
4 changes: 2 additions & 2 deletions docs/_deprecating_soon.ipynb
@@ -180,8 +180,8 @@
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[\"llama_guard\"],\n",
" output_shields=[\"llama_guard\"],\n",
" input_shields=[\"Llama-Guard-3-1B\"],\n",
" output_shields=[\"Llama-Guard-3-1B\"],\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
62 changes: 20 additions & 42 deletions docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
"description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
},
"servers": [
{
@@ -5743,9 +5743,6 @@
"const": "shield",
"default": "shield"
},
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"params": {
"type": "object",
"additionalProperties": {
@@ -5777,20 +5774,10 @@
"identifier",
"provider_resource_id",
"provider_id",
"type",
"shield_type"
"type"
],
"title": "A safety shield resource that can be used to check content"
},
"ShieldType": {
"type": "string",
"enum": [
"generic_content_shield",
"llama_guard",
"code_scanner",
"prompt_guard"
]
},
"Trace": {
"type": "object",
"properties": {
@@ -7262,9 +7249,6 @@
"shield_id": {
"type": "string"
},
"shield_type": {
"$ref": "#/components/schemas/ShieldType"
},
"provider_shield_id": {
"type": "string"
},
@@ -7299,8 +7283,7 @@
},
"additionalProperties": false,
"required": [
"shield_id",
"shield_type"
"shield_id"
]
},
"RunEvalRequest": {
@@ -7854,58 +7837,58 @@
],
"tags": [
{
"name": "Inference"
"name": "MemoryBanks"
},
{
"name": "BatchInference"
},
{
"name": "Agents"
},
{
"name": "Telemetry"
"name": "Inference"
},
{
"name": "Eval"
"name": "DatasetIO"
},
{
"name": "Models"
"name": "Eval"
},
{
"name": "Inspect"
"name": "Models"
},
{
"name": "EvalTasks"
"name": "PostTraining"
},
{
"name": "ScoringFunctions"
},
{
"name": "Memory"
"name": "Datasets"
},
{
"name": "Safety"
"name": "Shields"
},
{
"name": "DatasetIO"
"name": "Telemetry"
},
{
"name": "MemoryBanks"
"name": "Inspect"
},
{
"name": "Shields"
"name": "Safety"
},
{
"name": "PostTraining"
"name": "SyntheticDataGeneration"
},
{
"name": "Datasets"
"name": "Memory"
},
{
"name": "Scoring"
},
{
"name": "SyntheticDataGeneration"
},
{
"name": "BatchInference"
"name": "EvalTasks"
},
{
"name": "BuiltinTool",
@@ -8255,10 +8238,6 @@
"name": "Shield",
"description": "A safety shield resource that can be used to check content\n\n<SchemaDefinition schemaRef=\"#/components/schemas/Shield\" />"
},
{
"name": "ShieldType",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ShieldType\" />"
},
{
"name": "Trace",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Trace\" />"
@@ -8614,7 +8593,6 @@
"Session",
"Shield",
"ShieldCallStep",
"ShieldType",
"SpanEndPayload",
"SpanStartPayload",
"SpanStatus",
42 changes: 13 additions & 29 deletions docs/resources/llama-stack-spec.yaml
@@ -2227,11 +2227,8 @@ components:
type: string
shield_id:
type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
required:
- shield_id
- shield_type
type: object
RestAPIExecutionConfig:
additionalProperties: false
@@ -2698,8 +2695,6 @@ components:
type: string
provider_resource_id:
type: string
shield_type:
$ref: '#/components/schemas/ShieldType'
type:
const: shield
default: shield
@@ -2709,7 +2704,6 @@ components:
- provider_resource_id
- provider_id
- type
- shield_type
title: A safety shield resource that can be used to check content
type: object
ShieldCallStep:
@@ -2736,13 +2730,6 @@ components:
- step_id
- step_type
type: object
ShieldType:
enum:
- generic_content_shield
- llama_guard
- code_scanner
- prompt_guard
type: string
SpanEndPayload:
additionalProperties: false
properties:
@@ -3397,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
\ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
\ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4761,24 +4748,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
- name: Inference
- name: MemoryBanks
- name: BatchInference
- name: Agents
- name: Telemetry
- name: Inference
- name: DatasetIO
- name: Eval
- name: Models
- name: Inspect
- name: EvalTasks
- name: ScoringFunctions
- name: Memory
- name: Safety
- name: DatasetIO
- name: MemoryBanks
- name: Shields
- name: PostTraining
- name: ScoringFunctions
- name: Datasets
- name: Scoring
- name: Shields
- name: Telemetry
- name: Inspect
- name: Safety
- name: SyntheticDataGeneration
- name: BatchInference
- name: Memory
- name: Scoring
- name: EvalTasks
- description: <SchemaDefinition schemaRef="#/components/schemas/BuiltinTool" />
name: BuiltinTool
- description: <SchemaDefinition schemaRef="#/components/schemas/CompletionMessage"
@@ -5046,8 +5033,6 @@ tags:
<SchemaDefinition schemaRef="#/components/schemas/Shield" />'
name: Shield
- description: <SchemaDefinition schemaRef="#/components/schemas/ShieldType" />
name: ShieldType
- description: <SchemaDefinition schemaRef="#/components/schemas/Trace" />
name: Trace
- description: 'Checkpoint created during training runs
@@ -5343,7 +5328,6 @@ x-tagGroups:
- Session
- Shield
- ShieldCallStep
- ShieldType
- SpanEndPayload
- SpanStartPayload
- SpanStatus
6 changes: 3 additions & 3 deletions docs/zero_to_hero_guide/06_Safety101.ipynb
@@ -182,13 +182,13 @@
" pass\n",
"\n",
" async def run_shield(\n",
" self, shield_type: str, messages: List[dict]\n",
" self, shield_id: str, messages: List[dict]\n",
" ) -> RunShieldResponse:\n",
" async with httpx.AsyncClient() as client:\n",
" response = await client.post(\n",
" f\"{self.base_url}/safety/run_shield\",\n",
" json=dict(\n",
" shield_type=shield_type,\n",
" shield_id=shield_id,\n",
" messages=[encodable_dict(m) for m in messages],\n",
" ),\n",
" headers={\n",
@@ -216,7 +216,7 @@
" ]:\n",
" cprint(f\"User>{message['content']}\", \"green\")\n",
" response = await client.run_shield(\n",
" shield_type=\"llama_guard\",\n",
" shield_id=\"Llama-Guard-3-1B\",\n",
" messages=[message],\n",
" )\n",
" print(response)\n",
1 change: 0 additions & 1 deletion llama_stack/apis/datasets/datasets.py
@@ -38,7 +38,6 @@ def provider_dataset_id(self) -> str:
return self.provider_resource_id


@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
1 change: 0 additions & 1 deletion llama_stack/apis/eval_tasks/eval_tasks.py
@@ -34,7 +34,6 @@ def provider_eval_task_id(self) -> str:
return self.provider_resource_id


@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None
1 change: 0 additions & 1 deletion llama_stack/apis/memory_banks/memory_banks.py
@@ -122,7 +122,6 @@ class GraphMemoryBank(MemoryBankResourceMixin):
]


@json_schema_type
class MemoryBankInput(BaseModel):
memory_bank_id: str
params: BankParams
1 change: 0 additions & 1 deletion llama_stack/apis/models/models.py
@@ -32,7 +32,6 @@ def provider_model_id(self) -> str:
return self.provider_resource_id


@json_schema_type
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
4 changes: 2 additions & 2 deletions llama_stack/apis/safety/client.py
@@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:


def encodable_dict(d: BaseModel):
return json.loads(d.json())
return json.loads(d.model_dump_json())


class SafetyClient(Safety):
@@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
)
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_id="llama_guard",
shield_id="Llama-Guard-3-1B",
messages=[message],
)
print(response)
1 change: 0 additions & 1 deletion llama_stack/apis/scoring_functions/scoring_functions.py
@@ -96,7 +96,6 @@ def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id


@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
6 changes: 2 additions & 4 deletions llama_stack/apis/shields/client.py
@@ -37,7 +37,6 @@ async def list_shields(self) -> List[Shield]:
async def register_shield(
self,
shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str],
provider_id: Optional[str],
params: Optional[Dict[str, Any]],
@@ -47,7 +46,6 @@ async def register_shield(
f"{self.base_url}/shields/register",
json={
"shield_id": shield_id,
"shield_type": shield_type,
"provider_shield_id": provider_shield_id,
"provider_id": provider_id,
"params": params,
@@ -56,12 +54,12 @@ async def register_shield(
)
response.raise_for_status()

async def get_shield(self, shield_type: str) -> Optional[Shield]:
async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
params={
"shield_type": shield_type,
"shield_id": shield_id,
},
headers={"Content-Type": "application/json"},
)