diff --git a/ui/src/components/app/section-navigator.tsx b/ui/src/components/app/section-navigator.tsx new file mode 100644 index 00000000..aaa4637b --- /dev/null +++ b/ui/src/components/app/section-navigator.tsx @@ -0,0 +1,60 @@ +import { cn } from "@/lib/utils"; +import { Card, CardContent, CardHeader, CardTitle } from "../ui/card"; +import { useEffect, useState } from "react"; + +export const SectionNavigator = ({ sections }: { sections: { title: string; id: string }[] }) => { + const [activeSection, setActiveSection] = useState<{ title: string; id: string } | null>(null); + + const handleScroll = () => { + // Use reduce instead of find for obtaining the last section that is in view + const currentSection = sections.reduce((result: { title: string; id: string } | null, section) => { + const secElement = document.getElementById(section.id); + if (!secElement) return result; + const rect = secElement.getBoundingClientRect(); + if (rect.top <= window.innerHeight / 2) { + return section; + } + return result; + }, null); + + setActiveSection(currentSection); + }; + + useEffect(() => { + window.addEventListener("scroll", handleScroll); + + // Run the handler to set the initial active section + handleScroll(); + + return () => { + window.removeEventListener("scroll", handleScroll); + }; + }); + + return ( + + + + CONTENTS + + + +
+ +
+
+
+ ); +}; diff --git a/ui/src/components/dictionary/sample.tsx b/ui/src/components/dictionary/sample.tsx index 82ac192b..a5f5f23c 100644 --- a/ui/src/components/dictionary/sample.tsx +++ b/ui/src/components/dictionary/sample.tsx @@ -136,17 +136,19 @@ export const DictionarySampleArea = ({ samples, onSamplesChange, dictionaryName ...featureAct, })) ) - .reduce((acc, featureAct) => { - // Group by featureActIndex - const key = featureAct.featureActIndex.toString(); - if (acc[key]) { - acc[key].push(featureAct); - } else { - acc[key] = [featureAct]; - } - return acc; - }, {} as Record) || - {} + .reduce( + (acc, featureAct) => { + // Group by featureActIndex + const key = featureAct.featureActIndex.toString(); + if (acc[key]) { + acc[key].push(featureAct); + } else { + acc[key] = [featureAct]; + } + return acc; + }, + {} as Record + ) || {} ) .sort( // Sort by sum of featureAct diff --git a/ui/src/components/feature/feature-card.tsx b/ui/src/components/feature/feature-card.tsx index fecc1a5f..fa413694 100644 --- a/ui/src/components/feature/feature-card.tsx +++ b/ui/src/components/feature/feature-card.tsx @@ -88,7 +88,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => { const [showCustomInput, setShowCustomInput] = useState(false); return ( - + @@ -108,7 +108,7 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => { -
+

Activation Histogram

{
{feature.logits && ( -
+

Logits

@@ -180,15 +180,24 @@ export const FeatureCard = ({ feature }: { feature: Feature }) => {
)} -
+
- {feature.sampleGroups.map((sampleGroup) => ( + {feature.sampleGroups.slice(0, feature.sampleGroups.length / 2).map((sampleGroup) => ( {analysisNameMap(sampleGroup.analysisName)} ))} + + {feature.sampleGroups + .slice(feature.sampleGroups.length / 2, feature.sampleGroups.length) + .map((sampleGroup) => ( + + {analysisNameMap(sampleGroup.analysisName)} + + ))} + {feature.sampleGroups.map((sampleGroup) => ( { const [page, setPage] = useState(1); - const maxPage = Math.ceil(sampleGroup.samples.length / 5); + const maxPage = Math.ceil(sampleGroup.samples.length / 10); return (

Max Activation: {Math.max(...sampleGroup.samples[0].featureActs).toFixed(3)}

- {sampleGroup.samples.slice((page - 1) * 5, page * 5).map((sample, i) => ( + {sampleGroup.samples.slice((page - 1) * 10, page * 10).map((sample, i) => ( ))} @@ -69,18 +70,68 @@ export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: F [0] ); + const tokensList = tokens.map((t) => t.featureAct); + const startTrigger = Math.max(tokensList.indexOf(Math.max(...tokensList)) - 100, 0); + const endTrigger = Math.min(tokensList.indexOf(Math.max(...tokensList)) + 10, sample.context.length); + const tokensTrigger = sample.context.slice(startTrigger, endTrigger).map((token, i) => ({ + token, + featureAct: sample.featureActs[startTrigger + i], + })); + + const [tokenGroupsTrigger, __] = tokensTrigger.reduce<[Token[][], Token[]]>( + ([groups, currentGroup], token) => { + const newGroup = [...currentGroup, token]; + try { + decoder.decode(mergeUint8Arrays(newGroup.map((t) => t.token))); + return [[...groups, newGroup], []]; + } catch { + return [groups, newGroup]; + } + }, + [[], []] + ); + + const tokenGroupPositionsTrigger = tokenGroupsTrigger.reduce( + (acc, tokenGroup) => { + const tokenCount = tokenGroup.length; + return [...acc, acc[acc.length - 1] + tokenCount]; + }, + [0] + ); + return (
- {sampleName && {sampleName}: } - {tokenGroups.map((tokens, i) => ( - - ))} + + + +
+ {sampleName && {sampleName}: } + {startTrigger != 0 && ...} + {tokenGroupsTrigger.map((tokens, i) => ( + + ))} + {endTrigger != 0 && ...} +
+
+ + {tokenGroups.map((tokens, i) => ( + + ))} + +
+
); }; diff --git a/ui/src/components/ui/accordion.tsx b/ui/src/components/ui/accordion.tsx index bb60ae56..3068d676 100644 --- a/ui/src/components/ui/accordion.tsx +++ b/ui/src/components/ui/accordion.tsx @@ -22,7 +22,7 @@ const AccordionTrigger = React.forwardRef< svg]:rotate-180", + "flex flex-1 items-center justify-between py-4 font-medium transition-all [&[data-state=open]>svg]:rotate-180", className )} {...props} diff --git a/ui/src/globals.css b/ui/src/globals.css index 9e242755..55ba9ebb 100644 --- a/ui/src/globals.css +++ b/ui/src/globals.css @@ -74,3 +74,7 @@ @apply bg-background text-foreground; } } + +html { + scroll-behavior: smooth; +} diff --git a/ui/src/routes/features/page.tsx b/ui/src/routes/features/page.tsx index f55925f3..8ea03ee0 100644 --- a/ui/src/routes/features/page.tsx +++ b/ui/src/routes/features/page.tsx @@ -1,5 +1,6 @@ import { AppNavbar } from "@/components/app/navbar"; import { FeatureCard } from "@/components/feature/feature-card"; +import { SectionNavigator } from "@/components/app/section-navigator"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; @@ -87,8 +88,23 @@ export const FeaturesPage = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, [dictionariesState.value]); + const sections = [ + { + title: "Histogram", + id: "Histogram", + }, + { + title: "Logits", + id: "Logits", + }, + { + title: "Top Activation", + id: "Activation", + }, + ].filter((section) => (featureState.value && featureState.value.logits != null) || section.id !== "Logits"); + return ( -
+
@@ -142,6 +158,7 @@ export const FeaturesPage = () => { Show Random Feature
+ {featureState.loading && !loadingRandomFeature && (
Loading Feature #{featureIndex}... @@ -149,7 +166,12 @@ export const FeaturesPage = () => { )} {featureState.loading && loadingRandomFeature &&
Loading Random Living Feature...
} {featureState.error &&
Error: {featureState.error.message}
} - {!featureState.loading && featureState.value && } + {!featureState.loading && featureState.value && ( +
+ + +
+ )}
);