source: trunk/DirectShowSpy/Common.h @ 188

Last change on this file since 188 was 95, checked in by roman, 12 years ago

DirectShowSpy? moved from Assembla

  • Property svn:keywords set to Id
File size: 15.3 KB
Line 
1////////////////////////////////////////////////////////////
2// Copyright (C) Roman Ryltsov, 2008-2011
3// Created by Roman Ryltsov roman@alax.info
4//
5// $Id: Common.h 95 2012-08-17 17:21:13Z roman $
6
7#pragma once
8
9////////////////////////////////////////////////////////////
10// CProcessTokenPrivileges
11
12class CProcessTokenPrivileges
13{
14public:
15        CAccessToken m_ProcessToken;
16        CTokenPrivileges m_OriginalProcessTokenPrivileges;
17        BOOL m_bTakeOwnershipPrivilegeEnabled;
18        BOOL m_bRestorePrivilegeEnabled;
19
20public:
21// CProcessTokenPrivileges
22        CProcessTokenPrivileges() throw() :
23                m_bTakeOwnershipPrivilegeEnabled(FALSE),
24                m_bRestorePrivilegeEnabled(FALSE)
25        {
26        }
27        ~CProcessTokenPrivileges()
28        {
29                if(m_bTakeOwnershipPrivilegeEnabled)
30                {
31                        __E(m_ProcessToken.EnableDisablePrivileges(m_OriginalProcessTokenPrivileges));
32                        m_bRestorePrivilegeEnabled = FALSE;
33                        m_bTakeOwnershipPrivilegeEnabled = FALSE;
34                }
35        }
36        VOID Adjust()
37        {
38                _A(!m_ProcessToken.GetHandle());
39                // NOTE: Enable SE_TAKE_OWNERSHIP_NAME and SE_RESTORE_NAME privileges in order to be able to access registry key permissions
40                __E(m_ProcessToken.GetProcessToken(TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES));
41                __E(m_ProcessToken.EnablePrivilege(SE_TAKE_OWNERSHIP_NAME, &m_OriginalProcessTokenPrivileges));
42                m_bTakeOwnershipPrivilegeEnabled = TRUE;
43                __E(m_ProcessToken.EnablePrivilege(SE_RESTORE_NAME));
44                m_bRestorePrivilegeEnabled = TRUE;
45        }
46};
47
48////////////////////////////////////////////////////////////
49// CClassIdentifierRegKeySecurity
50
51class CClassIdentifierRegKeySecurity
52{
53public:
54        CLSID m_ClassIdentifier;
55        CRegKey m_Key;
56        CSecurityDesc m_OriginalSecurityDescriptor;
57
58public:
59// CClassIdentifierRegKeySecurity
60        static VOID GetRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation, CSecurityDesc& SecurityDescriptor)
61        {
62                DWORD nSecurityDescriptorSize = 0;
63                Key.GetKeySecurity(SecurityInformation, NULL, &nSecurityDescriptorSize);
64                CTempBuffer<SECURITY_DESCRIPTOR> pSecurityDescriptor;
65                _W(pSecurityDescriptor.AllocateBytes(nSecurityDescriptorSize));
66                __C(HRESULT_FROM_WIN32(Key.GetKeySecurity(SecurityInformation, pSecurityDescriptor, &nSecurityDescriptorSize)));
67                SecurityDescriptor = *pSecurityDescriptor;
68        }
69        static CString StringFromSecurityDescriptor(CSecurityDesc& SecurityDescriptor)
70        {
71                CString sSecurityDescriptorString;
72                __E(SecurityDescriptor.ToString(&sSecurityDescriptorString));
73                return sSecurityDescriptorString;
74        }
75        static CString StringFromRegKeySecurityDescriptor(CRegKey& Key, SECURITY_INFORMATION SecurityInformation)
76        {
77                CSecurityDesc SecurityDescriptor;
78                GetRegKeySecurityDescriptor(Key, SecurityInformation, SecurityDescriptor);
79                return StringFromSecurityDescriptor(SecurityDescriptor);
80        }
81        CClassIdentifierRegKeySecurity(const CLSID& ClassIdentifier) throw() :
82                m_ClassIdentifier(ClassIdentifier)
83        {
84        }
85        ~CClassIdentifierRegKeySecurity()
86        {
87                if(m_Key)
88                {
89                        // QUES: Would it be more reliable to leave ownership and revert DACL only in absence of Restore privilege?
90                        __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) m_OriginalSecurityDescriptor))));
91                        _Z4(atlTraceGeneral, 4, _T("Owner & DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION));
92                        __C(HRESULT_FROM_WIN32(m_Key.Close()));
93                }
94        }
95        BOOL Adjust()
96        {
97                if(m_Key)
98                        return FALSE;
99                const CString sKeyName = AtlFormatString(_T("CLSID\\%ls"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier));
100                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_OWNER)));
101                GetRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION | DACL_SECURITY_INFORMATION, m_OriginalSecurityDescriptor);
102                // NOTE:
103                //   Windows 5.1: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:AB (Administrators)
104                //   Windows 6.0: CLSID {E436EBB2-524F-11CE-9F53-0020AF0BA770} Key Owner: O:S-1-5-80-956008885-3418522649-1831038044-1853292631-2271478464 (TrustedInstaller)
105                _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls, m_OriginalSecurityDescriptor %s\n"), _PersistHelper::StringFromIdentifier(m_ClassIdentifier), StringFromSecurityDescriptor(m_OriginalSecurityDescriptor));
106                // NOTE: Take ownership of the key to Administrators in order to be able to update key DACL
107                // QUES: Would it be more reliable to take ownership to self in absence of Restore privilege?
108                CSecurityDesc AdministratorsOwnerSecurityDescriptor;
109                AdministratorsOwnerSecurityDescriptor.SetOwner(Sids::Admins());
110                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(OWNER_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AdministratorsOwnerSecurityDescriptor))));
111                _Z4(atlTraceGeneral, 4, _T("Owner: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, OWNER_SECURITY_INFORMATION));
112                // NOTE: Reopen the key to obtain given privileges
113                __C(HRESULT_FROM_WIN32(m_Key.Close()));
114                __C(HRESULT_FROM_WIN32(m_Key.Open(HKEY_CLASSES_ROOT, sKeyName, READ_CONTROL | WRITE_DAC | WRITE_OWNER)));
115                // NOTE: Adjust key DACL in order to make the key writable
116                CSecurityDesc AccessListSecurityDescriptor = m_OriginalSecurityDescriptor;
117                CDacl AccessList;
118                __E(AccessListSecurityDescriptor.GetDacl(&AccessList));
119                _W(AccessList.AddAllowedAce(Sids::Admins(), GENERIC_ALL, CONTAINER_INHERIT_ACE));
120                AccessListSecurityDescriptor.SetDacl(AccessList);
121                __C(HRESULT_FROM_WIN32(m_Key.SetKeySecurity(DACL_SECURITY_INFORMATION, const_cast<SECURITY_DESCRIPTOR*>((const SECURITY_DESCRIPTOR*) AccessListSecurityDescriptor))));
122                _Z4(atlTraceGeneral, 4, _T("DACL: %s\n"), StringFromRegKeySecurityDescriptor(m_Key, DACL_SECURITY_INFORMATION));
123                return TRUE;
124        }
125};
126
127////////////////////////////////////////////////////////////
128// TreatAsUpdateRegistryFromResource
129
130template <typename T>
131inline VOID TreatAsUpdateRegistryFromResource(const CLSID& TreatAsClassIdentifier, BOOL bRegister)
132{
133        _Z2(atlTraceRegistrar, 2, _T("TreatAsClassIdentifier %ls, bRegister %d\n"), _PersistHelper::StringFromIdentifier(TreatAsClassIdentifier), bRegister);
134        // NOTE: Registration is much more sophisticated starting Vista operating system
135        const ULONG nOsVersion = GetOsVersion();
136        _Z4(atlTraceRegistrar, 4, _T("nOsVersion 0x%08x\n"), nOsVersion);
137        CProcessTokenPrivileges ProcessTokenPrivileges;
138        if(nOsVersion >= 0x060000) // Vista+
139                _ATLTRY
140                {
141                        ProcessTokenPrivileges.Adjust();
142                }
143                _ATLCATCHALL()
144                {
145                        _Z_EXCEPTION();
146                }
147        CLSID CurrentTreatAsClassIdentifier = CLSID_NULL;
148        const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
149        __C(nCoGetTreatAsClassResult);
150        _Z4(atlTraceRegistrar, 4, _T("bRegister %d, nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), bRegister, nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
151        __D(!bRegister || nCoGetTreatAsClassResult != S_OK || CurrentTreatAsClassIdentifier == T::GetObjectCLSID(), E_UNNAMED);
152        CClassIdentifierRegKeySecurity ClassIdentifierRegKeySecurity(TreatAsClassIdentifier);
153        if(!bRegister && nCoGetTreatAsClassResult == S_OK)
154        {
155                if(nOsVersion >= 0x060000) // Vista+
156                        ClassIdentifierRegKeySecurity.Adjust();
157                __C(CoTreatAsClass(TreatAsClassIdentifier, CLSID_NULL));
158        }
159        _A(_pAtlModule);
160        UpdateRegistryFromResource<T>(bRegister);
161        if(bRegister)
162        {
163                if(nOsVersion >= 0x060000) // Vista+
164                        ClassIdentifierRegKeySecurity.Adjust();
165#if _DEVELOPMENT
166                const HRESULT nCoTreatAsClassResult = CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID());
167                _Z2(atlTraceRegistrar, SUCCEEDED(nCoTreatAsClassResult) ? 4 : 2, _T("nCoTreatAsClassResult 0x%08x\n"), nCoTreatAsClassResult);
168                __C(nCoTreatAsClassResult);
169                const HRESULT nCoGetTreatAsClassResult = CoGetTreatAsClass(TreatAsClassIdentifier, &CurrentTreatAsClassIdentifier);
170                _Z4(atlTraceRegistrar, 4, _T("nCoGetTreatAsClassResult 0x%08x, CurrentTreatAsClassIdentifier %ls\n"), nCoGetTreatAsClassResult, _PersistHelper::StringFromIdentifier(CurrentTreatAsClassIdentifier));
171                _A(CurrentTreatAsClassIdentifier == T::GetObjectCLSID());
172#else
173                __C(CoTreatAsClass(TreatAsClassIdentifier, T::GetObjectCLSID()));
174#endif // _DEVELOPMENT
175        }
176}
177
178////////////////////////////////////////////////////////////
179// CBlackListAwareComCreatorT
180
181template <typename T, const CLSID* t_pClassIdentifier>
182class ATL_NO_VTABLE CTransparentCoClassT
183{
184public:
185// CTransparentCoClassT
186        static HINSTANCE CoLoadOriginalLibrary()
187        {
188                const HINSTANCE hModule = CoLoadLibrary(const_cast<LPOLESTR>((LPCOLESTR) CT2COLE(T::GetOriginalLibraryName())), TRUE);
189                __E(hModule);
190                return hModule;
191        }
192        static CComPtr<IUnknown> CoCreateOriginalInstance(HINSTANCE hModule, IUnknown* pControllingUnknown = NULL)
193        {
194                typedef HRESULT (WINAPI *DLLGETCLASSOBJECT)(REFCLSID, REFIID, VOID**);
195                DLLGETCLASSOBJECT DllGetClassObject = (DLLGETCLASSOBJECT) GetProcAddress(hModule, "DllGetClassObject");
196                __E(DllGetClassObject);
197                CComPtr<IClassFactory> pClassFactory;
198                __C(DllGetClassObject(*t_pClassIdentifier, __uuidof(IClassFactory), (VOID**) &pClassFactory));
199                _A(pClassFactory);
200                CComPtr<IUnknown> pUnknown;
201                __C(pClassFactory->CreateInstance(pControllingUnknown, __uuidof(IUnknown), (VOID**) &pUnknown));
202                return pUnknown;
203        }
204        static CComPtr<IUnknown> CoCreateOriginalInstance(IUnknown* pControllingUnknown = NULL)
205        {
206                CComPtr<IUnknown> pUnknown;
207                const HINSTANCE hModule = CoLoadOriginalLibrary();
208                _ATLTRY
209                {
210                        pUnknown = CoCreateOriginalInstance(hModule, pControllingUnknown);
211                }
212                _ATLCATCHALL()
213                {
214                        CoFreeLibrary(hModule);
215                        _ATLRETHROW;
216                }
217                CoFreeLibrary(hModule);
218                return pUnknown;
219        }
220};
221
222////////////////////////////////////////////////////////////
223// CBlackListAwareComCreatorT
224
225template <typename _ObjectClass, typename _Class, LPCTSTR* t_ppszName>
226class CBlackListAwareComCreatorT :
227        public CComCreator<_ObjectClass>
228{
229public:
230// CBlackListAwareComCreatorT
231        static HRESULT WINAPI CreateInstance(VOID* pvControllingUnknown, REFIID InterfaceIdentifier, VOID** ppvObject) throw()
232        {
233                _A(ppvObject);
234                *ppvObject = NULL;
235                #pragma region Check Black List
236                static INT g_nBlackListed = 0; // 0 Unknown, 1 No, 2 Yes
237                {
238                        _A(_pAtlModule);
239                        CComCritSecLock<CComCriticalSection> Lock(_pAtlModule->m_csStaticDataInitAndTypeInfo);
240                        if(!g_nBlackListed)
241                        {
242                                TCHAR pszPath[MAX_PATH] = { 0 };
243                                _W(GetModuleFileName(NULL, pszPath, DIM(pszPath)));
244                                LPTSTR pszFileName = FindFileName(pszPath);
245                                _A(pszFileName);
246                                RemoveExtension(pszFileName);
247                                const CString sBlackList = _RegKeyHelper::QueryStringValue(HKEY_LOCAL_MACHINE, REGISTRY_ROOT, AtlFormatString(_T("%s Black List"), *t_ppszName));
248                                CRoArrayT<CString> BlackListArray;
249                                _StringHelper::GetCommaSeparatedItems(sBlackList, BlackListArray);
250                                BOOL bFound = FALSE;
251                                for(SIZE_T nIndex = 0; nIndex < BlackListArray.GetCount(); nIndex++)
252                                {
253                                        CPath sFileName = (LPCTSTR) BlackListArray[nIndex];
254                                        sFileName.RemoveExtension();
255                                        if(_tcsicmp(sFileName, pszFileName) == 0)
256                                        {
257                                                _Z2(atlTraceCOM, 2, _T("Will instantiate original COM class, sFileName \"%s\"\n"), sFileName);
258                                                bFound = TRUE;
259                                                break;
260                                        }
261                                }
262                                g_nBlackListed = bFound ? 2 : 1;
263                        }
264                }
265                #pragma endregion
266                if(g_nBlackListed == 2)
267                        #pragma region CoCreateInstance Original Class
268                        _ATLTRY
269                        {
270                                _A(_pAtlModule);
271                                const LONG nLockResult = _pAtlModule->Lock();
272                                _ATLTRY
273                                {
274                                        CComPtr<IUnknown> pUnknown = _Class::CoCreateOriginalInstance((IUnknown*) pvControllingUnknown);
275                                        if(InterfaceIdentifier == __uuidof(IUnknown))
276                                                *ppvObject = pUnknown.Detach();
277                                        else
278                                                __C(pUnknown->QueryInterface(InterfaceIdentifier, ppvObject));
279                                }
280                                _ATLCATCHALL()
281                                {
282                                        _pAtlModule->Unlock();
283                                        _ATLRETHROW;
284                                }
285                                const LONG nUnlockResult = _pAtlModule->Unlock();
286                                _Z6(atlTraceGeneral, 6, _T("nLockResult %d, nUnlockResult %d\n"), nLockResult, nUnlockResult);
287                                return S_OK;
288                        }
289                        _ATLCATCH(Exception)
290                        {
291                                _A(FAILED(Exception));
292                                _C(Exception);
293                        }
294                        #pragma endregion
295                return __super::CreateInstance(pvControllingUnknown, InterfaceIdentifier, ppvObject);
296        }
297};
298
299////////////////////////////////////////////////////////////
300// CHookHostT
301
302template <typename T, typename IHook, LPCTSTR* t_ppszHookName>
303class CHookHostT
304{
305public:
306
307        ////////////////////////////////////////////////////////
308        // CHookArray
309
310        class CHookArray :
311                public CRoArrayT<CComPtr<IHook> >
312        {
313        public:
314        // CHookArray
315        };
316
317private:
318        mutable CRoCriticalSection m_HookCriticalSection;
319        BOOL m_bHookArrayInitialized;
320        CHookArray m_HookArray;
321
322        VOID InitializeHookArray()
323        {
324                _A(m_HookArray.IsEmpty());
325                CRoListT<CLSID> ClassIdentifierList;
326                static const HKEY g_phParentKeys[] = { HKEY_LOCAL_MACHINE, HKEY_CURRENT_USER };
327                static const LPCTSTR g_ppszKeyNameFormats[] = { _T("SOFTWARE\\Classes\\"), _T("Software\\Classes\\") };
328                for(SIZE_T nKeyIndex = 0; nKeyIndex < DIM(g_phParentKeys); nKeyIndex++)
329                {
330                        const CString sKeyName = AtlFormatString(_T("%sCLSID\\%ls\\Hooks\\%s"), g_ppszKeyNameFormats[nKeyIndex], _PersistHelper::StringFromIdentifier(T::GetObjectCLSID()), *t_ppszHookName);
331                        CRegKey Key;
332                        if(FAILED(HRESULT_FROM_WIN32(Key.Open(g_phParentKeys[nKeyIndex], sKeyName, KEY_READ))))
333                                continue;
334                        for(DWORD nIndex = 0; ; nIndex++)
335                        {
336                                DWORD nNameLength = 0;
337                                RegEnumKeyEx(Key, nIndex, NULL, &nNameLength, NULL, NULL, NULL, NULL);
338                                nNameLength = max(2 * nNameLength, 256);
339                                CTempBuffer<TCHAR, 4096> pszName(nNameLength);
340                                const HRESULT nRegEnumKeyResult = HRESULT_FROM_WIN32(RegEnumKeyEx(Key, nIndex, pszName, &nNameLength, NULL, NULL, NULL, NULL));
341                                if(FAILED(nRegEnumKeyResult))
342                                {
343                                        __D(nRegEnumKeyResult == HRESULT_FROM_WIN32(ERROR_NO_MORE_ITEMS), nRegEnumKeyResult);
344                                        break;
345                                }
346                                _ATLTRY
347                                {
348                                        const CLSID ClassIdentifier = _PersistHelper::ClassIdentifierFromString(CT2CW(pszName));
349                                        _Z4(atlTraceGeneral, 4, _T("ClassIdentifier %ls\n"), _PersistHelper::StringFromIdentifier(ClassIdentifier));
350                                        __D(ClassIdentifier != CLSID_NULL, E_UNNAMED);
351                                        if(ClassIdentifierList.Find(ClassIdentifier))
352                                                continue;
353                                        _W(ClassIdentifierList.AddTail(ClassIdentifier));
354                                        CComPtr<IHook> pHook;
355                                        __C(pHook.CoCreateInstance(ClassIdentifier));
356                                        _W(m_HookArray.Add(pHook) >= 0);
357                                }
358                                _ATLCATCHALL()
359                                {
360                                        _Z_EXCEPTION();
361                                }
362                        }
363                }
364        }
365
366public:
367// CHookHostT
368        CHookHostT() throw() :
369                m_bHookArrayInitialized(FALSE)
370        {
371        }
372        SIZE_T GetHookArray(CHookArray& HookArray)
373        {
374                _A(HookArray.IsEmpty());
375                CRoCriticalSectionLock HookLock(m_HookCriticalSection);
376                if(!m_bHookArrayInitialized)
377                        _ATLTRY
378                        {
379                                m_bHookArrayInitialized = TRUE;
380                                InitializeHookArray();
381                        }
382                        _ATLCATCHALL()
383                        {
384                                _Z_EXCEPTION();
385                        }
386                HookArray.Append(m_HookArray);
387                return HookArray.GetCount();
388        }
389};
390
391#define HOOK_PROLOG(Base) \
392        _ATLTRY \
393        { \
394                Base::CHookArray HookArray; \
395                if(Base::GetHookArray(HookArray)) \
396                { \
397                        T* pT = static_cast<T*>(this); \
398                        for(SIZE_T nIndex = 0; nIndex < HookArray.GetCount(); nIndex++) \
399                        { \
400                                BOOL bDefault = TRUE; \
401                                const HRESULT nResult = HookArray[nIndex]->
402                                               
403#define HOOK_EPILOG() \
404                                if(!bDefault) \
405                                        return nResult; \
406                                _A(SUCCEEDED(nResult)); \
407                        } \
408                } \
409        } \
410        _ATLCATCHALL() \
411        { \
412                _Z_EXCEPTION(); \
413        }
414
Note: See TracBrowser for help on using the repository browser.