Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for individual model permissions #1631

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ ADMIN_API_SECRET=# secret to admin API calls, like computing usage stats or expo
# These values cannot be updated at runtime
# They need to be passed when building the docker image
# See https://github.com/huggingface/chat-ui/main/.github/workflows/deploy-prod.yml#L44-L47
APP_BASE="" # base path of the app, e.g. /chat, left blank as default
APP_BASE= # base path of the app, e.g. /chat
PUBLIC_APP_COLOR=blue # can be any of tailwind colors: https://tailwindcss.com/docs/customizing-colors#default-color-palette
### Body size limit for SvelteKit https://svelte.dev/docs/kit/adapter-node#Environment-variables-BODY_SIZE_LIMIT
BODY_SIZE_LIMIT=15728640
Expand Down
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ OPENID_CONFIG=`{
SCOPES: "openid profile",
TOLERANCE: // optional
RESOURCE: // optional
PROVIDER: // required only for group-based permissions
}`
```

Expand Down Expand Up @@ -337,6 +338,47 @@ We currently support [IDEFICS](https://huggingface.co/blog/idefics) (hosted on T
}
```

#### Group-based Model Permissions

If [logging in with OpenID](#openid-connect) via a supported provider, then user groups can be used in combination with the `allowed_groups` field for each model to show/hide models to users based on their group membership.

For all providers, see the following. Then, see additional instructions for your provider below.

1. Add `PROVIDER: "<provider-name-here>"` to your `.env.local`. Also, add `groups` to the `OPENID_CONFIG.SCOPES` field in your `.env.local` file:

```env
OPENID_CONFIG=`{
// rest of OPENID_CONFIG here
PROVIDER: "<provider-name-here>",
SCOPES: "openid profile groups",
// rest of OPENID_CONFIG here
}`
```

2. Use the `allowed_groups` parameter for each model to specify which group(s) should have access to that model. If not specified, all users will be able to access the model.

> [!WARNING]
> The first model in your `.env.local` file is considered the "default" model and should be available to all users, so we strongly recommend against setting `allowed_groups` for this model.

#### Provider: Microsoft Entra

In order to enable use of [Microsoft Entra Security Groups](https://learn.microsoft.com/en-us/entra/fundamentals/concept-learn-about-groups) to show/hide models, do the following:

1. Replace `<provider-name-here>` with `entra` in `.env.local`.

2. `allowed_groups` for each model in `.env.local` should be a list of Microsoft Entra **Group IDs** (not group names), e.g.:

```env
{
// rest of the model config here
"allowed_groups": ["123abcde-1234-abcd-cdef-1234567890ab", "abcde123-abcd-1234-cdef-abcdef123456"]
}
```

3. Finally, configure your app in Microsoft Entra so that the app can access user groups via the MS Graph API:
- [Add groups claim](https://learn.microsoft.com/en-gb/entra/identity-platform/optional-claims?tabs=appui#configure-groups-optional-claims) to your app
- [Enable ID Tokens](https://learn.microsoft.com/en-us/entra/identity-platform/v2-protocols-oidc#enable-id-tokens) for your app

#### Running your own models using a custom endpoint

If you want to, instead of hitting models on the Hugging Face Inference API, you can run your own models locally.
Expand Down
2 changes: 1 addition & 1 deletion src/app.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ declare global {
// interface Error {}
interface Locals {
sessionId: string;
user?: User & { logoutDisabled?: boolean };
user?: User & { logoutDisabled?: boolean; groups?: string[] };
}

interface Error {
Expand Down
58 changes: 57 additions & 1 deletion src/hooks.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { sha256 } from "$lib/utils/sha256";
import { addWeeks } from "date-fns";
import { checkAndRunMigrations } from "$lib/migrations/migrations";
import { building } from "$app/environment";
import { logout, OIDConfig, ProviderCookieNames } from "$lib/server/auth";
import { type AccessToken, providers } from "$lib/server/providers/providers";
import { logger } from "$lib/server/logger";
import { AbortedGenerations } from "$lib/server/abortedGenerations";
import { MetricsServer } from "$lib/server/metrics";
Expand Down Expand Up @@ -229,7 +231,12 @@ export const handle: Handle = async ({ event, resolve }) => {
...(envPublic.PUBLIC_ORIGIN ? [new URL(envPublic.PUBLIC_ORIGIN).host] : []),
];

if (!validOrigins.includes(new URL(origin).host)) {
// origin is null when the POST request callback comes from an auth provider like MS entra
// so we skip this check (CSRF token is still validated)
if (
event.url.pathname !== `${base}/login/callback` &&
!validOrigins.includes(new URL(origin).host)
) {
return errorResponse(403, "Invalid referer for POST request");
}
}
Expand Down Expand Up @@ -278,6 +285,55 @@ export const handle: Handle = async ({ event, resolve }) => {
}
}

// Get user groups for allowed models
if (OIDConfig.PROVIDER && OIDConfig.SCOPES.includes("groups")) {
const provider = providers[OIDConfig.PROVIDER];
const session_exists = event.cookies.get(env.COOKIE_NAME) !== undefined;

let accessToken: AccessToken = JSON.parse(
event.cookies.get(ProviderCookieNames.ACCESS_TOKEN)?.toString() || "{}"
);
let providerParameters = JSON.parse(
event.cookies.get(ProviderCookieNames.PROVIDER_PARAMS)?.toString() || "{}"
);

// If user is logged in, get/refresh access token and use it to retrieve user groups
if (event.locals.user) {
// Get access token upon login with id token
if (accessToken && providerParameters.idToken) {
[accessToken, providerParameters] = await provider.getAccessToken(
event.cookies,
providerParameters
);
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
}
// Refresh access token on subsequent requests
else if (accessToken.refreshToken && providerParameters.userTid) {
accessToken = await provider.refreshAccessToken(
event.cookies,
accessToken,
providerParameters
);
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
}
// Logout user automatically if session exists but access token and/or provider params cookies have expired
else if (session_exists) {
event.locals.user.groups = undefined;
await logout(event.cookies, event.locals);
}
}
} else if (OIDConfig.SCOPES.includes("groups")) {
return errorResponse(
500,
"'groups' has been set in OPENID_CONFIG.SCOPES, but OPENID_CONFIG.PROVIDER is undefined in .env file"
);
} else if (OIDConfig.PROVIDER) {
return errorResponse(
500,
"OPENID_CONFIG.PROVIDER has been set, but 'groups' scope not set in OPENID_CONFIG.SCOPES in .env file"
);
}

let replaced = false;

const response = await resolve(event, {
Expand Down
2 changes: 1 addition & 1 deletion src/lib/components/AssistantSettings.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
bind:value={modelId}
>
{#each models.filter((model) => !model.unlisted) as model}
{#each models as model}
<option value={model.id}>{model.displayName}</option>
{/each}
<p class="text-xs text-red-500">{getError("modelId", form)}</p>
Expand Down
3 changes: 1 addition & 2 deletions src/lib/components/NavMenu.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import NavConversationItem from "./NavConversationItem.svelte";
import type { LayoutData } from "../../routes/$types";
import type { ConvSidebar } from "$lib/types/ConvSidebar";
import type { Model } from "$lib/types/Model";
import { page } from "$app/stores";

export let conversations: ConvSidebar[];
Expand Down Expand Up @@ -43,7 +42,7 @@
older: "Older",
} as const;

const nModels: number = $page.data.models.filter((el: Model) => !el.unlisted).length;
const nModels: number = $page.data.models.length;
</script>

<div class="sticky top-0 flex flex-none items-center justify-between px-1.5 py-3.5 max-sm:pt-0">
Expand Down
51 changes: 48 additions & 3 deletions src/lib/server/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ import {
type TokenSet,
custom,
} from "openid-client";
import { redirect } from "@sveltejs/kit";
import { addHours, addWeeks } from "date-fns";
import { env } from "$env/dynamic/private";
import { sha256 } from "$lib/utils/sha256";
import { z } from "zod";
import { dev } from "$app/environment";
import { base } from "$app/paths";
import type { Cookies } from "@sveltejs/kit";
import { collections } from "$lib/server/database";
import JSON5 from "json5";
import { logger } from "$lib/server/logger";

export interface OIDCSettings {
redirectURI: string;
response_type?: string;
response_mode?: string | undefined;
nonce?: string | undefined;
}

export interface OIDCUserInfo {
Expand All @@ -34,6 +39,7 @@ export const OIDConfig = z
.object({
CLIENT_ID: stringWithDefault(env.OPENID_CLIENT_ID),
CLIENT_SECRET: stringWithDefault(env.OPENID_CLIENT_SECRET),
PROVIDER: stringWithDefault(env.OPENID_PROVIDER || ""),
PROVIDER_URL: stringWithDefault(env.OPENID_PROVIDER_URL),
SCOPES: stringWithDefault(env.OPENID_SCOPES),
NAME_CLAIM: stringWithDefault(env.OPENID_NAME_CLAIM).refine(
Expand All @@ -46,8 +52,15 @@ export const OIDConfig = z
})
.parse(JSON5.parse(env.OPENID_CONFIG || "{}"));

export const ProviderCookieNames = {
ACCESS_TOKEN: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-access-token" : "",
PROVIDER_PARAMS: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-params" : "",
};

export const requiresUser = !!OIDConfig.CLIENT_ID && !!OIDConfig.CLIENT_SECRET;

export const responseType = OIDConfig.SCOPES.includes("groups") ? "code id_token" : "code";

const sameSite = z
.enum(["lax", "none", "strict"])
.default(dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none")
Expand Down Expand Up @@ -108,7 +121,7 @@ async function getOIDCClient(settings: OIDCSettings): Promise<BaseClient> {
client_id: OIDConfig.CLIENT_ID,
client_secret: OIDConfig.CLIENT_SECRET,
redirect_uris: [settings.redirectURI],
response_types: ["code"],
response_types: ["code", "id_token"],
[custom.clock_tolerance]: OIDConfig.TOLERANCE || undefined,
id_token_signed_response_alg: OIDConfig.ID_TOKEN_SIGNED_RESPONSE_ALG || undefined,
};
Expand All @@ -131,8 +144,13 @@ export async function getOIDCAuthorizationUrl(

return client.authorizationUrl({
scope: OIDConfig.SCOPES,
state: csrfToken,
state: Buffer.from(JSON.stringify({ csrfToken, sessionId: params.sessionId })).toString(
"base64"
),
resource: OIDConfig.RESOURCE || undefined,
response_type: settings.response_type,
response_mode: settings.response_mode,
nonce: settings.nonce,
});
}

Expand All @@ -142,7 +160,11 @@ export async function getOIDCUserData(
iss?: string
): Promise<OIDCUserInfo> {
const client = await getOIDCClient(settings);
const token = await client.callback(settings.redirectURI, { code, iss });
const token = await client.callback(
settings.redirectURI,
{ code, iss },
{ nonce: settings.nonce }
);
const userData = await client.userinfo(token);

return { token, userData };
Expand Down Expand Up @@ -175,3 +197,26 @@ export async function validateAndParseCsrfToken(
}
return null;
}

export async function logout(cookies: Cookies, locals: App.Locals) {
await collections.sessions.deleteOne({ sessionId: locals.sessionId });

const cookie_names = [env.COOKIE_NAME];
if (ProviderCookieNames.ACCESS_TOKEN) {
cookie_names.push(ProviderCookieNames.ACCESS_TOKEN);
}
if (ProviderCookieNames.PROVIDER_PARAMS) {
cookie_names.push(ProviderCookieNames.PROVIDER_PARAMS);
}

for (const cookie_name of cookie_names) {
cookies.delete(cookie_name, {
path: env.APP_BASE || "/",
// So that it works inside the space's iframe
sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none",
secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"),
httpOnly: true,
});
}
redirect(303, `${base}/`);
}
1 change: 1 addition & 0 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ const modelConfig = z.object({
multimodal: z.boolean().default(false),
multimodalAcceptedMimetypes: z.array(z.string()).optional(),
tools: z.boolean().default(false),
allowed_groups: z.array(z.string()).optional(),
unlisted: z.boolean().default(false),
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
/** Used to enable/disable system prompt usage */
Expand Down
Loading
Loading